Code example #1
    def test_initialize_model_parallel_with_virtual_and_split(self) -> None:
        if self.world_size < 4:
            self.skipTest("requires >= 4 GPUs")
        self.assertFalse(parallel_state.model_parallel_is_initialized())

        tensor_model_parallel_world_size = 1 + int(self.world_size > 4)
        pipeline_model_parallel_world_size = (self.world_size //
                                              tensor_model_parallel_world_size)
        virtual_pipeline_model_parallel_world_size = 2
        pipeline_model_parallel_split_rank = pipeline_model_parallel_world_size // 2

        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size_=tensor_model_parallel_world_size,
            pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
            virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_world_size,
            pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
        )
        self.assertEqual(
            calc_expected_tensor_model_paralell_rank(
                self.rank, tensor_model_parallel_world_size),
            parallel_state.get_tensor_model_parallel_rank(),
        )
        self.assertEqual(
            pipeline_model_parallel_world_size,
            parallel_state.get_pipeline_model_parallel_world_size(),
        )
        self.assertEqual(
            virtual_pipeline_model_parallel_world_size,
            parallel_state.get_virtual_pipeline_model_parallel_world_size(),
        )

        expected_pipeline_rank = (
            self.rank - self.rank % tensor_model_parallel_world_size
        ) % pipeline_model_parallel_world_size
        self.assertEqual(
            expected_pipeline_rank,
            parallel_state.get_pipeline_model_parallel_rank(),
        )
        # virtual pipeline model parallel rank is lazily set, i.e., right after the call of
        # `initialize_model_parallel`, it's set to 0.
        self.assertEqual(
            0,
            parallel_state.get_virtual_pipeline_model_parallel_rank(),
        )
        self.assertEqual(
            pipeline_model_parallel_split_rank,
            parallel_state.get_pipeline_model_parallel_split_rank(),
        )

        fake_split_rank = 77
        parallel_state.set_pipeline_model_parallel_split_rank(fake_split_rank)
        self.assertEqual(
            fake_split_rank,
            parallel_state.get_pipeline_model_parallel_split_rank())

        parallel_state.destroy_model_parallel()
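
For concreteness, a minimal standalone sketch of the sizes this test picks on an 8-GPU run; the arithmetic mirrors the lines above, and the world size is only illustrative.

world_size = 8
tensor_model_parallel_world_size = 1 + int(world_size > 4)                            # 2
pipeline_model_parallel_world_size = world_size // tensor_model_parallel_world_size   # 4
virtual_pipeline_model_parallel_world_size = 2
pipeline_model_parallel_split_rank = pipeline_model_parallel_world_size // 2          # 2
print(tensor_model_parallel_world_size,
      pipeline_model_parallel_world_size,
      virtual_pipeline_model_parallel_world_size,
      pipeline_model_parallel_split_rank)  # 2 4 2 2
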
Code example #2
def convert(local_rank, rank, world_size, args):

    app_state = AppState()
    app_state.data_parallel_rank = 0
    num_nodes = world_size // args.gpus_per_node
    if args.bcp:
        trainer = Trainer(devices=args.gpus_per_node,
                          num_nodes=num_nodes,
                          accelerator='gpu',
                          plugins=[TorchElasticEnvironment()])
    else:
        trainer = Trainer(devices=args.gpus_per_node,
                          num_nodes=num_nodes,
                          accelerator='gpu')

    app_state.pipeline_model_parallel_size = args.pipeline_model_parallel_size
    app_state.tensor_model_parallel_size = args.tensor_model_parallel_size
    app_state.model_parallel_size = app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size

    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=app_state.tensor_model_parallel_size,
        pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size,
    )

    app_state.pipeline_model_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
    app_state.tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()

    # inject model parallel rank
    checkpoint_path = inject_model_parallel_rank(
        os.path.join(args.checkpoint_folder, args.checkpoint_name))

    logging.info(
        f'rank: {rank}, local_rank: {local_rank}, is loading checkpoint: {checkpoint_path} for tp_rank: {app_state.tensor_model_parallel_rank} and pp_rank: {app_state.pipeline_model_parallel_rank}'
    )

    if args.model_type == 'gpt':
        model = MegatronGPTModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    elif args.model_type == 'bert':
        model = MegatronBertModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    elif args.model_type == 't5':
        model = MegatronT5Model.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    elif args.model_type == 'nmt':
        model = MegatronNMTModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    model._save_restore_connector = NLPSaveRestoreConnector()

    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    model.save_to(args.nemo_file_path)

    logging.info(f'NeMo model saved to: {args.nemo_file_path}')
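
A hypothetical sketch of the `args` namespace this converter reads. The attribute names are taken from the code above; the concrete values are only illustrative.

import argparse

args = argparse.Namespace(
    gpus_per_node=8,
    bcp=False,
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=1,
    checkpoint_folder="/path/to/checkpoints",
    checkpoint_name="megatron_gpt.ckpt",
    hparams_file="/path/to/hparams.yaml",
    model_type="gpt",
    nemo_file_path="/path/to/megatron_gpt.nemo",
)
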
Code example #3
def _set_random_seed(seed_):
    """Set random seed for reproducability."""
    if seed_ is not None and seed_ > 0:
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + (100 * get_pipeline_model_parallel_rank())
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            tensor_parallel.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError(
            'Seed ({}) should be a positive integer.'.format(seed_))
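
A minimal usage sketch, assuming the pipeline model parallel groups have already been initialized: every pipeline stage derives its own seed from a single base seed.

base_seed = 1234
_set_random_seed(base_seed)  # pipeline stage k effectively seeds with 1234 + 100 * k
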
Code example #4
File: nlp_overrides.py  Project: NVIDIA/NeMo
    def init_model_parallel(self, global_rank: int, world_size: int) -> None:
        """ Initializes Megatron-LM model parallel if using model parallelism.

        Args:
            global_rank (int): the global process index.
            world_size (int): the total number of GPUs, num_nodes * num_devices
            is_slurm_managing_tasks (bool, optional): is the cluster managed by SLURM.
        """
        app_state = AppState()

        # we initialize megatron-lm model parallel and data parallel groups
        # after initializing DDP with PTL.
        if app_state.model_parallel_size is not None:
            # destroy groups in case they have already been created
            # this happens with multiple calls to trainer.test for example
            parallel_state.destroy_model_parallel()
            if torch.distributed.is_initialized():
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=app_state.tensor_model_parallel_size,
                    pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size,
                    pipeline_model_parallel_split_rank_=app_state.pipeline_model_parallel_split_rank,
                )

                # assert that fake tp and pp rank match after model parallel init
                assert app_state.tensor_model_parallel_rank == parallel_state.get_tensor_model_parallel_rank()
                assert app_state.pipeline_model_parallel_rank == parallel_state.get_pipeline_model_parallel_rank()

                app_state.tensor_model_parallel_group = parallel_state.get_tensor_model_parallel_group()
                app_state.data_parallel_group = parallel_state.get_data_parallel_group()
                app_state.data_parallel_rank = parallel_state.get_data_parallel_rank()
                app_state.data_parallel_size = parallel_state.get_data_parallel_world_size()
                app_state.pipeline_model_parallel_group = parallel_state.get_pipeline_model_parallel_group()
Code example #5
def init_weights(m):
    # `offset` and `weight_coeff` are expected to be defined in the enclosing scope.
    rank = parallel_state.get_pipeline_model_parallel_rank()
    if isinstance(m, torch.nn.Linear):
        m.weight.fill_((rank + offset + 1.0) / weight_coeff)
        m.bias.fill_(1.0)
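
A hypothetical usage sketch. `weight_coeff` and `offset` are assumed to come from the enclosing scope, `fill_` mutates parameters in place so the call runs under `torch.no_grad()`, and model parallel state must already be initialized for the rank lookup.

weight_coeff = 1024.0
offset = 0
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
with torch.no_grad():
    model.apply(init_weights)  # fills every Linear layer deterministically per pipeline rank
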
Code example #6
File: transformer.py  Project: AlexGrinch/NeMo
    def __init__(
        self,
        init_method,
        output_layer_init_method,
        num_layers,
        hidden_size,
        ffn_hidden_size,
        num_attention_heads,
        apply_query_key_layer_scaling=True,
        kv_channels=None,
        layer_type=LayerType.encoder,
        self_attn_mask_type=AttnMaskType.padding,
        pre_process=True,
        post_process=True,
        precision=16,
        fp32_residual_connection=False,
        activations_checkpoint_method=None,
        activations_checkpoint_num_layers=1,
        layernorm_epsilon=1e-5,
        hidden_dropout=0.1,
        use_cpu_initialization=False,
        bias_gelu_fusion=True,
        openai_gelu=False,
        onnx_safe=False,
    ):
        super(ParallelTransformer, self).__init__()

        if kv_channels is None:
            assert (
                hidden_size % num_attention_heads == 0
            ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
            kv_channels = hidden_size // num_attention_heads

        self.fp32_residual_connection = fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None

        # Store activation checkpointing flag.
        self.activations_checkpoint_method = activations_checkpoint_method
        self.activations_checkpoint_num_layers = activations_checkpoint_num_layers

        # Number of layers.
        assert (
            num_layers %
            parallel_state.get_pipeline_model_parallel_world_size() == 0
        ), 'num_layers must be divisible by pipeline_model_parallel_size'
        self.num_layers = num_layers // parallel_state.get_pipeline_model_parallel_world_size()

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method=init_method,
                output_layer_init_method=output_layer_init_method,
                layer_number=layer_number,
                hidden_size=hidden_size,
                ffn_hidden_size=ffn_hidden_size,
                num_attention_heads=num_attention_heads,
                apply_query_key_layer_scaling=apply_query_key_layer_scaling,
                kv_channels=kv_channels,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type,
                precision=precision,
                fp32_residual_connection=fp32_residual_connection,
                layernorm_epsilon=layernorm_epsilon,
                hidden_dropout=hidden_dropout,
                use_cpu_initialization=use_cpu_initialization,
                bias_gelu_fusion=bias_gelu_fusion,
                openai_gelu=openai_gelu,
                onnx_safe=onnx_safe,
            )

        # TODO: get virtual_pipeline_model_parallel_size from apex.mpu
        # if parallel_state.get_virtual_pipeline_model_parallel_rank() is not None:
        #     assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, (
        #         'num_layers_per_stage must be divisible by ' 'virtual_pipeline_model_parallel_size'
        #     )
        #     # Number of layers in each model chunk is the number of layers in the stage,
        #     # divided by the number of model chunks in a stage.
        #     self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
        #     # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
        #     # layers to stages like (each list is a model chunk):
        #     # Stage 0: [0]  [2]  [4]  [6]
        #     # Stage 1: [1]  [3]  [5]  [7]
        #     # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
        #     # layers to stages like (each list is a model chunk):
        #     # Stage 0: [0, 1]  [4, 5]
        #     # Stage 1: [2, 3]  [6, 7]
        #     offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
        #         args.num_layers // args.virtual_pipeline_model_parallel_size
        #     ) + (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
        # else:
        #     # Each stage gets a contiguous set of layers.
        #     offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers
        offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers

        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(hidden_size,
                                             eps=layernorm_epsilon)
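
A worked sketch of the per-stage layer split computed above; the layer and stage counts are only illustrative.

num_layers, pipeline_size = 24, 4
layers_per_stage = num_layers // pipeline_size  # 6
for rank in range(pipeline_size):
    offset = rank * layers_per_stage
    print(rank, [offset + i + 1 for i in range(layers_per_stage)])
# rank 0 builds layers 1..6, rank 1 builds 7..12, ..., rank 3 builds 19..24
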
Code example #7
def build_model(
        model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
        wrap_with_ddp: bool = True,
        virtual_pipeline_model_parallel_size: Optional[int] = None,
        *args,
        **kwargs
) -> List[torch.nn.Module]:
    """Build the model satisfying pipeline model parallel requirements.

    This function sets `pre_process` and `post_process` in `**kwargs` and passes `*args` and `**kwargs` to
    `model_provider_func`.

    Args:
        model_provider_func: A function which takes `*args` and `**kwargs` and returns a `nn.Module`.
        wrap_with_ddp: If :obj:`True`, wrap the instantiated model
            with `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`.
        virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel.
        *args: arguments for model provider func
        **kwargs: Keyword arguments for model provider func

    Returns:
        a list of `nn.Module`(s). If `virtual_pipeline_model_parallel_size` is not None,
        the list has multiple models, otherwise one.
    """
    if (
            parallel_state.get_pipeline_model_parallel_world_size() > 1 and
            virtual_pipeline_model_parallel_size is not None
    ):
        model = []
        for i in range(virtual_pipeline_model_parallel_size):
            cur_args = args
            cur_kwargs = kwargs
            parallel_state.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
            })
            this_model = model_provider_func(*cur_args, **cur_kwargs)
            model.append(this_model)
    else:
        cur_args = args
        cur_kwargs = kwargs
        pre_process = parallel_state.is_pipeline_first_stage()
        post_process = parallel_state.is_pipeline_last_stage()
        cur_kwargs.update({
            "pre_process": pre_process,
            "post_process": post_process,
        })
        model = model_provider_func(*cur_args, **cur_kwargs)

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if parallel_state.get_data_parallel_rank() == 0:
        msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
            parallel_state.get_tensor_model_parallel_rank(),
            parallel_state.get_pipeline_model_parallel_rank(),
            sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model])
        )
        print(msg, flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    if wrap_with_ddp:
        i = torch.cuda.current_device()
        model = [
            torch.nn.parallel.distributed.DistributedDataParallel(
                model_module,
                device_ids=[i],
                output_device=i,
                process_group=parallel_state.get_data_parallel_group(),
            )
            for model_module in model
        ]
    return model
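
A hypothetical provider sketch: any callable that accepts the injected `pre_process`/`post_process` keyword arguments and returns an `nn.Module` satisfies the `model_provider_func` contract. The toy module and the guard are assumptions, not part of the original.

def toy_model_provider(hidden_size: int = 32, *, pre_process: bool = True, post_process: bool = True) -> torch.nn.Module:
    # pre_process / post_process would normally gate embeddings and the output head.
    return torch.nn.Linear(hidden_size, hidden_size)

if parallel_state.model_parallel_is_initialized():
    models = build_model(toy_model_provider, wrap_with_ddp=False, hidden_size=64)
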
Code example #8
def forward_backward_pipelining_without_interleaving(
    forward_step_func: FwdStepFunc,
    batch: Batch,
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
):
    """Run non-interleaved 1F1B schedule, with communication between pipeline stages.

    This pipeline parallel scheduling consists of three steps:
        1. warmup
        2. 1F1B a.k.a. steady state
        3. cooldown if not forward_only

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`'s.
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only:
        tensor_shape: Shape of tensor. Required for P2P communication.

    Returns:
        a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
    """
    # timers = get_timers()

    model = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected be a `nn.Module`, but {type(model)}"
        raise RuntimeError(msg)
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = (
        parallel_state.get_pipeline_model_parallel_world_size() -
        parallel_state.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining = num_microbatches - num_warmup_microbatches

    _logger.info(f"num_microbatches: {num_microbatches}, "
                 f"num_warmup_microbatches: {num_warmup_microbatches}, "
                 f"num_microbatches_remaining: {num_microbatches_remaining}")

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
    output_tensors = None
    if not forward_only:
        input_tensors = []
        output_tensors = []
    losses_reduced = []

    ###################################################################################################################
    # Run warmup forward passes.
    ###################################################################################################################
    _logger.info("Warmup")
    for i in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
        _logger.debug("receive fwd")
        input_tensor = p2p_communication.recv_forward(
            tensor_shape=tensor_shape)
        cur_microbatch = get_kth_microbatch(batch, i)
        output_tensor = forward_step(forward_step_func, cur_microbatch, model,
                                     input_tensor, losses_reduced)
        _logger.debug("send fwd")
        p2p_communication.send_forward(output_tensor,
                                       tensor_shape=tensor_shape)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        _logger.debug("recv_forward before steady state start")
        input_tensor = p2p_communication.recv_forward(
            tensor_shape=tensor_shape)

    ###################################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################################
    _logger.info("Steady phase")
    for i in range(num_microbatches_remaining):
        _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
        last_iteration = i == (num_microbatches_remaining - 1)

        cur_microbatch = get_kth_microbatch(batch, i + num_warmup_microbatches)
        output_tensor = forward_step(forward_step_func, cur_microbatch, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            _logger.debug("send fwd")
            p2p_communication.send_forward(output_tensor,
                                           tensor_shape=tensor_shape)

            if not last_iteration:
                _logger.debug("receive fwd (last iteration)")
                input_tensor = p2p_communication.recv_forward(
                    tensor_shape=tensor_shape)

        else:
            _logger.debug("send fwd & receive bwd")
            output_tensor_grad = p2p_communication.send_forward_recv_backward(
                output_tensor, tensor_shape=tensor_shape)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

            # Pop input_tensor and output_tensor from the start of the list for the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = backward_step(input_tensor, output_tensor,
                                              output_tensor_grad)

            if last_iteration:
                input_tensor = None
                _logger.debug("send bwd")
                p2p_communication.send_backward(input_tensor_grad,
                                                tensor_shape=tensor_shape)
            else:
                _logger.debug("send bwd and receive fwd")
                input_tensor = p2p_communication.send_backward_recv_forward(
                    input_tensor_grad, tensor_shape=tensor_shape)
    ###################################################################################################################
    # Run cooldown backward passes.
    ###################################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        for i in range(num_warmup_microbatches):
            _logger.debug(f"cooldown iter: {i} / {num_warmup_microbatches}")
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            _logger.debug("receive bwd")
            output_tensor_grad = p2p_communication.recv_backward(
                tensor_shape=tensor_shape)

            input_tensor_grad = backward_step(input_tensor, output_tensor,
                                              output_tensor_grad)

            _logger.debug("send bwd")
            p2p_communication.send_backward(input_tensor_grad,
                                            tensor_shape=tensor_shape)

    return losses_reduced
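
A hypothetical `forward_step_func` sketch matching the contract described in the docstring: it takes a microbatch and the model, and returns the forward output together with a loss function that maps that output to a loss tensor and a dict of named tensors.

def toy_forward_step_func(microbatch, model):
    output_tensor = model(microbatch)

    def loss_func(output_tensor):
        loss = output_tensor.float().mean()
        return loss, {"avg_loss": loss.detach()}

    return output_tensor, loss_func

# losses = forward_backward_pipelining_without_interleaving(
#     toy_forward_step_func, batch, model,
#     forward_only=False, tensor_shape=tensor_shape)
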
Code example #9
def convert(local_rank, rank, world_size, args):

    app_state = AppState()
    app_state.data_parallel_rank = 0
    tensor_model_parallel_size = args.tensor_model_parallel_size
    num_nodes = world_size // args.gpus_per_node
    pipeline_model_parallel_size = world_size // args.tensor_model_parallel_size
    assert args.pipeline_model_parallel_size == pipeline_model_parallel_size

    trainer = Trainer(devices=args.gpus_per_node,
                      accelerator='gpu',
                      num_nodes=num_nodes)

    app_state.pipeline_model_parallel_size = args.pipeline_model_parallel_size
    app_state.tensor_model_parallel_size = args.tensor_model_parallel_size
    app_state.model_parallel_size = app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size

    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=app_state.tensor_model_parallel_size,
        pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size,
    )

    app_state.pipeline_model_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
    app_state.tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()

    pipeline_rank = rank // tensor_model_parallel_size
    tensor_rank = app_state.tensor_model_parallel_rank
    assert pipeline_rank == app_state.pipeline_model_parallel_rank

    if tensor_model_parallel_size is not None and tensor_model_parallel_size > 1 and pipeline_model_parallel_size == 1:
        # inject model parallel rank
        checkpoint_path = os.path.join(args.checkpoint_folder,
                                       f'mp_rank_{tensor_rank:02d}',
                                       args.checkpoint_name)
    elif tensor_model_parallel_size is not None and pipeline_model_parallel_size > 1:
        checkpoint_path = os.path.join(
            args.checkpoint_folder,
            f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}',
            args.checkpoint_name)
    else:
        checkpoint_path = os.path.join(args.checkpoint_folder,
                                       args.checkpoint_name)
    logging.info(f"loading checkpoint {checkpoint_path}")

    if args.model_type == 'gpt':
        ## this dictionary is used to rename the model parameters
        name_translate = {}
        name_translate['transformer'] = 'encoder'
        name_translate['.attention.'] = '.self_attention.'
        # nemo megatron doesn't have _for_head key
        name_translate['word_embeddings_for_head'] = 'word_embeddings'
        checkpoint, consumed, steps, version = load_from_checkpoint(
            MegatronGPTModel,
            checkpoint_path,
            hparams_file=args.hparams_file,
            trainer=trainer,
            translator=name_translate,
            strict=False,
        )
    elif args.model_type == 'bert':
        ## this dictionary is used to rename the model parameters
        name_translate = {}
        name_translate['transformer'] = 'encoder'
        name_translate['.attention.'] = '.self_attention.'
        # nemo megatron doesn't have _for_head key
        name_translate['word_embeddings_for_head'] = 'word_embeddings'
        checkpoint, consumed, steps, version = load_from_checkpoint(
            MegatronBertModel,
            checkpoint_path,
            hparams_file=args.hparams_file,
            trainer=trainer,
            translator=name_translate,
            strict=False,
        )
    else:
        raise NotImplementedError("{} is not supported".format(args.model_type))

    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    if args.output_ckpt_file_path:
        filepath = args.output_ckpt_file_path
        base_dir = pathlib.Path(filepath).parent
        filename_str = pathlib.Path(filepath).name
        suffix = '.ckpt'
        content = {}
        if consumed is not None:
            content['consumed'] = consumed
        else:
            content['consumed'] = 0
        if steps is not None:
            content['steps'] = steps
        else:
            content['steps'] = 0
        filename = filename_str.format(**content) + suffix
        checkpoint_path_output = inject_model_parallel_rank(
            os.path.join(base_dir, filename))
        trainer.accelerator.training_type_plugin.checkpoint_io.save_checkpoint(
            checkpoint, checkpoint_path_output)
        logging.info(
            f'NeMo model checkpoint files saved to: {args.output_ckpt_file_path}'
        )

    if args.nemo_file_path:
        if args.model_type == 'gpt':
            model = load_model(MegatronGPTModel,
                               checkpoint,
                               strict=False,
                               trainer=trainer)
        elif args.model_type == 'bert':
            model = load_model(MegatronBertModel,
                               checkpoint,
                               strict=False,
                               trainer=trainer)
        else:
            raise NotImplementedError("{} is not supported".format(args.model_type))

        # verify tensor parallel rank id and pipeline parallel rank id matches
        assert app_state.data_parallel_size == 1
        assert app_state.tensor_model_parallel_size == tensor_model_parallel_size
        assert app_state.tensor_model_parallel_rank == tensor_rank
        assert app_state.pipeline_model_parallel_size == pipeline_model_parallel_size
        assert app_state.pipeline_model_parallel_rank == pipeline_rank
        model._save_restore_connector = NLPSaveRestoreConnector()
        model.save_to(args.nemo_file_path)
        logging.info(f'NeMo model saved to: {args.nemo_file_path}')
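
A worked sketch of the checkpoint path layout assumed by the branches above; the folder, file name, and rank values are only illustrative.

import os

checkpoint_folder, checkpoint_name = "/ckpts", "model_optim_rng.pt"
tensor_rank, pipeline_rank = 1, 3
# tensor parallel only (pipeline_model_parallel_size == 1):
print(os.path.join(checkpoint_folder, f"mp_rank_{tensor_rank:02d}", checkpoint_name))
# -> /ckpts/mp_rank_01/model_optim_rng.pt
# tensor + pipeline parallel:
print(os.path.join(checkpoint_folder, f"mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}", checkpoint_name))
# -> /ckpts/mp_rank_01_003/model_optim_rng.pt
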
Code example #10
def _forward_backward_pipelining_with_interleaving(
    forward_step_func: FwdStepFunc,
    batch: List[Optional[Batch]],
    model: List[torch.nn.Module],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
    dtype: Optional[torch.dtype] = None,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    deallocate_pipeline_outputs: bool = False,
    **kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    """Run interleaved 1F1B schedule with communication between pipeline stages as needed.

    This function assumes `batch` and `model` are a list of `Batch`es and a list of `torch.nn.Module`s, respectively.
    This means that model is split into model chunks.

    This pipeline parallel scheduling consists of three steps:
        1. warmup
        2. 1F1B a.k.a. steady state
        3. cooldown
    Note that if `forward_only` this scheduling consists of only warmup phase.

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`'s.
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only:
        tensor_shape: Shape of tensor.
        dtype: dtype used in p2p communication. If ``None`` (default value),
            torch.float32 will be used even if ``autocast`` is enabled.
        grad_scaler:
        disable_autocast:
        deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
            each pipeline stage. Experimental.

    Returns:
        a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
    """
    if not isinstance(model, list):
        raise RuntimeError("`model` must be a list of `nn.Module`'s'")

    num_model_chunks: int = len(model)
    input_tensors: List[List[Union[None, torch.Tensor]]] = [
        [] for _ in range(num_model_chunks)
    ]
    output_tensors: List[List[Union[None, torch.Tensor]]] = [
        [] for _ in range(num_model_chunks)
    ]
    curr_iters: List[int] = [0 for _ in range(num_model_chunks)]
    losses_reduced: List[Union[None, torch.Tensor]] = []
    if not forward_only:
        output_tensor_grads: List[List[Union[None, torch.Tensor]]] = [
            [] for _ in range(num_model_chunks)
        ]

    pipeline_parallel_size: int = parallel_state.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank: int = parallel_state.get_pipeline_model_parallel_rank()

    # Compute number of warmup and remaining microbatches.
    num_microbatches: int = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches: bool = False
    if forward_only:
        num_warmup_microbatches: int = num_microbatches
    else:
        # Run all forward passes and then all backward passes if number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
        # all workers, followed by more microbatches after depending on
        # stage ID (more forward passes for earlier stages, later stages can
        # immediately start with 1F1B).
        if get_num_microbatches() == pipeline_parallel_size:
            num_warmup_microbatches = num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = (pipeline_parallel_size -
                                       pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (num_model_chunks -
                                        1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches,
                                          num_microbatches)
    num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches

    _logger.info(f"num_microbatches: {num_microbatches}, "
                 f"num_warmup_microbatches: {num_warmup_microbatches}, "
                 f"num_microbatches_remaining: {num_microbatches_remaining}")

    ###################################################################################################################
    # Helper function definitions.
    ###################################################################################################################
    def get_model_chunk_id(microbatch_id: int, forward: bool) -> int:
        """Helper function to get the model chunk ID given the iteration number."""
        pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size *
                                                  num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = num_model_chunks - model_chunk_id - 1
        return model_chunk_id

    def forward_step_helper(microbatch_id: int,
                            curr_iters: List[int]) -> torch.Tensor:
        """Helper method to run forward step with model split into chunks

        (run set_virtual_pipeline_model_parallel_rank() before calling forward_step()).
        """
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # forward step
        if (parallel_state.is_pipeline_first_stage()
                and len(input_tensors[model_chunk_id]) == len(
                    output_tensors[model_chunk_id])):
            input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(
            forward_step_func,
            get_kth_microbatch(batch, curr_iters[model_chunk_id]),
            model[model_chunk_id],
            input_tensor,
            losses_reduced,
            dtype,
            disable_autocast,
        )
        curr_iters[model_chunk_id] += 1
        output_tensors[model_chunk_id].append(output_tensor)

        # if forward-only, no need to save tensors for a backward pass
        if forward_only:
            input_tensors[model_chunk_id].pop()
            output_tensors[model_chunk_id].pop()

        return output_tensor

    def backward_step_helper(microbatch_id: int) -> torch.Tensor:
        """Helper method to run backward step with model split into chunks

        (run set_virtual_pipeline_model_parallel_rank() before calling backward_step()).
        """
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        model_type = get_model_type(model[model_chunk_id])
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if parallel_state.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = backward_step(
            input_tensor,
            output_tensor,
            output_tensor_grad,
            model_type=model_type,
            grad_scaler=grad_scaler,
            deallocate_pipeline_outputs=deallocate_pipeline_outputs)

        return input_tensor_grad

    ###################################################################################################################
    # Run warmup forward passes.
    ###################################################################################################################
    parallel_state.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype))
    _logger.info("Warmup phase")
    for k in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {k} / {num_warmup_microbatches}")
        output_tensor = forward_step_helper(k, curr_iters)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
        recv_prev = True
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (num_microbatches - 1):
            recv_prev = False
        _logger.debug(
            f"next fwd model chunk ID: {next_forward_model_chunk_id}, recv_prev: {recv_prev}"
        )

        # Don't send tensor downstream if on last stage.
        if parallel_state.is_pipeline_last_stage():
            _logger.debug("Pipeline last stage, not sending tensor downstream")
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        if k == (num_warmup_microbatches -
                 1) and not forward_only and not all_warmup_microbatches:
            input_tensor_grad = None
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            _logger.debug("send fwd&bwd and receive fwd&bwd")
            (
                input_tensor,
                output_tensor_grad,
            ) = p2p_communication.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
                tensor_shape=tensor_shape,
                dtype=dtype,
            )
            output_tensor_grads[num_model_chunks -
                                1].append(output_tensor_grad)
        else:
            _logger.debug("send fwd and receive fwd")
            input_tensor = p2p_communication.send_forward_recv_forward(
                output_tensor,
                recv_prev=recv_prev,
                tensor_shape=tensor_shape,
                dtype=dtype)
        free_output_tensor(output_tensor, deallocate_pipeline_outputs)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    ###################################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################################
    _logger.info("Steady phase")
    for k in range(num_microbatches_remaining):
        # Forward pass.
        _logger.debug(f" steady phase iter {k} / {num_microbatches_remaining}")
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k, curr_iters)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.

        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(
            forward_model_chunk_id)
        if parallel_state.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        parallel_state.set_virtual_pipeline_model_parallel_rank(
            backward_model_chunk_id)
        _logger.debug(
            f"fwd/bwd model chunk id: {forward_model_chunk_id}/{backward_model_chunk_id}"
        )
        if parallel_state.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(
                forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
                                                             forward=True)

        recv_next = True
        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
            next_backward_model_chunk_id = get_model_chunk_id(
                backward_k - (pipeline_parallel_size - 1), forward=False)
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
                                                              forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False

        # Communicate tensors.
        _logger.debug("send fwd&bwd and receive fwd&bwd")
        (
            input_tensor,
            output_tensor_grad,
        ) = p2p_communication.send_forward_backward_recv_forward_backward(
            output_tensor,
            input_tensor_grad,
            recv_prev=recv_prev,
            recv_next=recv_next,
            tensor_shape=tensor_shape,
            dtype=dtype,
        )
        free_output_tensor(output_tensor, deallocate_pipeline_outputs)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(
                output_tensor_grad)

    ###################################################################################################################
    # Run cooldown backward passes (flush out pipeline).
    ###################################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks - 1].append(
                p2p_communication.recv_backward(tensor_shape=tensor_shape,
                                                dtype=dtype))
        for k in range(num_microbatches_remaining, num_microbatches):
            _logger.debug(
                f"cooldown iter {k} in range({num_microbatches_remaining}, {num_microbatches})"
            )
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k + 1,
                                                              forward=False)
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad,
                    recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    dtype=dtype))

    return losses_reduced
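
A worked sketch of the chunk-assignment arithmetic used by `get_model_chunk_id` above, with illustrative sizes.

pipeline_parallel_size, num_model_chunks = 2, 2
for microbatch_id in range(pipeline_parallel_size * num_model_chunks):
    in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
    forward_chunk = in_group // pipeline_parallel_size
    backward_chunk = num_model_chunks - forward_chunk - 1
    print(microbatch_id, forward_chunk, backward_chunk)
# forward chunks per microbatch: 0, 0, 1, 1 -- backward order reverses: 1, 1, 0, 0
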
Code example #11
def build_model(
    model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
    wrap_with_ddp: bool = True,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    model_type: ModelType = ModelType.encoder_or_decoder,
    *args: Any,
    **kwargs: Any,
) -> List[torch.nn.Module]:
    """Build the model satisfying pipeline model parallel requirements.

    This function sets `pre_process` and `post_process` in `**kwargs` and passes `*args` and `**kwargs` to
    `model_provider_func`.

    Args:
        model_provider_func: A function which takes `*args` and `**kwargs` and returns a `nn.Module`.
        wrap_with_ddp: If :obj:`True`, wrap the instantiated model
            with `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`.
        virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel.
        model_type:
        *args: arguments for model provider func
        **kwargs: Keyword arguments for model provider func

    Returns:
        a list of `nn.Module`(s). If `virtual_pipeline_model_parallel_size` is not None,
        the list has multiple models, otherwise one.
    """
    if (parallel_state.get_pipeline_model_parallel_world_size() > 1
            and virtual_pipeline_model_parallel_size is not None):
        model = []
        for i in range(virtual_pipeline_model_parallel_size):
            cur_args = args
            cur_kwargs = kwargs
            parallel_state.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
            })
            this_model = model_provider_func(*cur_args, **cur_kwargs)
            model.append(this_model)
    else:
        cur_args = args
        cur_kwargs = kwargs
        if model_type == ModelType.encoder_or_decoder:
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
            })
            model = model_provider_func(*cur_args, **cur_kwargs)
        elif model_type == ModelType.encoder_and_decoder:
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            # `add_encoder` & `add_decoder` logic.
            add_encoder, add_decoder = True, True
            if parallel_state.get_pipeline_model_parallel_world_size() > 1:
                split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
                if split_rank is None:
                    raise RuntimeError(
                        "Split rank needs to be specified for model with both encoder and decoder."
                    )
                rank = parallel_state.get_pipeline_model_parallel_rank()
                world_size = parallel_state.get_pipeline_model_parallel_world_size()
                pre_process = rank == 0 or rank == split_rank
                post_process = rank == (split_rank - 1) or rank == (world_size - 1)
                add_encoder = parallel_state.is_pipeline_stage_before_split()
                add_decoder = parallel_state.is_pipeline_stage_after_split()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
                "add_encoder": add_encoder,
                "add_decoder": add_decoder,
            })
            model = model_provider_func(*cur_args, **cur_kwargs)
        model.model_type = model_type

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if (parallel_state.model_parallel_is_initialized()
            and parallel_state.get_data_parallel_rank() == 0):
        msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
            parallel_state.get_tensor_model_parallel_rank(),
            parallel_state.get_pipeline_model_parallel_rank(),
            _calc_number_of_params(model),
        )
        print(msg, flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    if wrap_with_ddp:
        i = torch.cuda.current_device()
        model = [
            torch.nn.parallel.distributed.DistributedDataParallel(
                model_module,
                device_ids=[i],
                output_device=i,
                process_group=parallel_state.get_data_parallel_group(),
            ) for model_module in model
        ]
    return model
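
A hypothetical provider sketch for `ModelType.encoder_and_decoder`: such a provider must also accept the injected `add_encoder`/`add_decoder` keyword arguments. The toy module is an assumption, not part of the original.

def toy_enc_dec_provider(hidden_size: int = 32, *, pre_process: bool = True, post_process: bool = True,
                         add_encoder: bool = True, add_decoder: bool = True) -> torch.nn.Module:
    modules = {}
    if add_encoder:
        modules["encoder"] = torch.nn.Linear(hidden_size, hidden_size)
    if add_decoder:
        modules["decoder"] = torch.nn.Linear(hidden_size, hidden_size)
    return torch.nn.ModuleDict(modules)

# models = build_model(toy_enc_dec_provider, wrap_with_ddp=False,
#                      model_type=ModelType.encoder_and_decoder, hidden_size=64)
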
Code example #12
def forward_backward_pipelining_without_interleaving(
    forward_step_func: FwdStepFunc,
    batch: Optional[Batch],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
    decoder_sequence_length: Optional[int] = None,
    dtype: Optional[torch.dtype] = None,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    deallocate_pipeline_outputs: bool = False,
    **kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    """Run non-interleaved 1F1B schedule, with communication between pipeline stages.

    This pipeline parallel scheduling consists of three steps:
        1. warmup
        2. 1F1B a.k.a. steady state
        3. cooldown if not forward_only

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`'s.
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only:
        tensor_shape: Shape of tensor. Required for P2P communication.
        dtype: dtype used in p2p communication. If ``None`` (default value),
            torch.float32 will be used even if ``autocast`` is enabled.
        grad_scaler:
        disable_autocast:
        deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
            each pipeline stage. Experimental.

    Returns:
        a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
    """
    # timers = get_timers()

    model: List[torch.nn.Module] = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected be a `nn.Module`, but {type(model)}"
        raise RuntimeError(msg)
    model: torch.nn.Module = model[0]

    # Compute number of warmup microbatches.
    num_microbatches: int = get_num_microbatches()
    num_warmup_microbatches: int = (
        parallel_state.get_pipeline_model_parallel_world_size() - parallel_state.get_pipeline_model_parallel_rank() - 1
    )
    num_warmup_microbatches: int = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches

    model_type = get_model_type(model)
    rank: int = parallel_state.get_pipeline_model_parallel_rank()
    recv_tensor_shapes: List[List[int]] = get_tensor_shapes(
        rank - 1, model_type, tensor_shape=tensor_shape, decoder_sequence_length=decoder_sequence_length
    )
    send_tensor_shapes: List[List[int]] = get_tensor_shapes(
        rank, model_type, tensor_shape=tensor_shape, decoder_sequence_length=decoder_sequence_length
    )

    _logger.info(
        f"num_microbatches: {num_microbatches}, "
        f"num_warmup_microbatches: {num_warmup_microbatches}, "
        f"num_microbatches_remaining: {num_microbatches_remaining}"
    )

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors: List[Union[None, torch.Tensor]] = []
    output_tensors: List[Union[None, torch.Tensor]] = []
    losses_reduced: List[Union[None, torch.Tensor]] = []
    ###################################################################################################################
    # Run warmup forward passes.
    ###################################################################################################################
    _logger.info("Warmup")
    for i in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
        _logger.debug("receive fwd")
        input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
        cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i)
        output_tensor = forward_step(
            forward_step_func,
            cur_microbatch,
            model,
            input_tensor,
            losses_reduced,
            dtype,
            disable_autocast,
        )
        _logger.debug("send fwd")
        send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, deallocate_pipeline_outputs)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        _logger.debug("recv_forward before steady state start")
        input_tensor: List[Union[None, torch.Tensor]] = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)

    ###################################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################################
    _logger.info("Steady phase")
    for i in range(num_microbatches_remaining):
        _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
        last_iteration: bool = i == (num_microbatches_remaining - 1)

        cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i + num_warmup_microbatches)
        output_tensor: Union[torch.Tensor, Sequence[torch.Tensor]] = forward_step(
            forward_step_func,
            cur_microbatch,
            model,
            input_tensor,
            losses_reduced,
            dtype,
            disable_autocast,
        )
        if forward_only:
            _logger.debug("send fwd")
            send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)

            if not last_iteration:
                _logger.debug("receive fwd (last iteration)")
                input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)

        else:
            _logger.debug("send fwd & receive bwd")
            output_tensor_grad = send_forward_recv_backward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, deallocate_pipeline_outputs)

            # Pop input_tensor and output_tensor from the start of the list for the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = backward_step(
                input_tensor,
                output_tensor,
                output_tensor_grad,
                model_type=model_type,
                grad_scaler=grad_scaler,
                deallocate_pipeline_outputs=deallocate_pipeline_outputs,
            )

            if last_iteration:
                input_tensor = None
                _logger.debug("send bwd")
                send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
            else:
                _logger.debug("send bwd and receive fwd")
                input_tensor = send_backward_recv_forward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
    ###################################################################################################################
    # Run cooldown backward passes.
    ###################################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        for i in range(num_warmup_microbatches):
            _logger.debug(f"cooldown iter: {i} / {num_warmup_microbatches}")
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            _logger.debug("receive bwd")
            output_tensor_grad = recv_backward(tensor_shapes=send_tensor_shapes, dtype=dtype)

            input_tensor_grad = backward_step(
                input_tensor,
                output_tensor,
                output_tensor_grad,
                model_type=model_type,
                grad_scaler=grad_scaler,
                deallocate_pipeline_outputs=deallocate_pipeline_outputs,
            )

            _logger.debug("send bwd")
            send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)

    return losses_reduced
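
A worked sketch of the warmup schedule computed above: with an illustrative pipeline size of 4 and 8 microbatches, earlier stages run more warmup forward passes before entering 1F1B.

pipeline_size, num_microbatches = 4, 8
for rank in range(pipeline_size):
    warmup = min(pipeline_size - rank - 1, num_microbatches)
    print(rank, warmup, num_microbatches - warmup)
# rank 0 runs 3 warmup microbatches; rank 3 (the last stage) runs none.
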