from contextlib import contextmanager
from typing import List, Union

import torch

from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.utils import get_model_type
from apex.transformer.pipeline_parallel.schedules.common import Batch
from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.log_util import get_transformer_logger


__all__ = ["forward_backward_no_pipelining"]

_logger = get_transformer_logger(__name__)


@contextmanager
def placeholder_handler():
    # No-op context manager used when the model provides no ``no_sync``
    # context, i.e. the gradient all-reduce is not deferred.
    try:
        yield
    finally:
        pass


def forward_backward_no_pipelining(
    forward_step_func: FwdStepFunc,
    batch: Batch,
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
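
# A minimal sketch, under assumptions, of the control flow a no-pipelining
# schedule implements -- not apex's implementation: run forward and backward
# once per microbatch on a single model chunk, suppressing the data-parallel
# gradient all-reduce for all but the last microbatch. ``fwd_step`` is a
# hypothetical callable returning a scalar loss; the real schedule goes
# through ``forward_step``/``backward_step`` imported above.
def _sketch_no_pipelining_loop(fwd_step, model, batch, num_microbatches):
    # Defer grad sync with DDP's ``no_sync`` when the model has one;
    # otherwise the no-op ``placeholder_handler`` means every microbatch syncs.
    sync_context = getattr(model, "no_sync", placeholder_handler)
    losses = []
    with sync_context():
        # All microbatches but the last accumulate gradients locally.
        for i in range(num_microbatches - 1):
            loss = fwd_step(get_kth_microbatch(batch, i), model)
            loss.backward()
            losses.append(loss)
    # The last microbatch runs outside the context, triggering the all-reduce.
    loss = fwd_step(get_kth_microbatch(batch, num_microbatches - 1), model)
    loss.backward()
    losses.append(loss)
    return losses
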
import torch
import torch.nn as nn

from apex.transformer.pipeline_parallel.schedules.common import _get_params_for_weight_decay_optimization
from apex.transformer.pipeline_parallel.schedules.common import build_model
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import _forward_backward_pipelining_with_interleaving
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import forward_backward_pipelining_without_interleaving
from apex.transformer.pipeline_parallel.utils import average_losses_across_data_parallel_group
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.pipeline_parallel.utils import update_num_microbatches
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing.commons import initialize_distributed
from apex.transformer.testing.commons import print_separator
from apex.transformer.log_util import get_transformer_logger, set_logging_level

# set_logging_level("INFO")
_logger = get_transformer_logger("pipeline_parallel_test")

global_vars.set_global_variables()

# Set later from the parsed global args.
batch_size, micro_batch_size = None, None
hidden_size = 16
fwd_bwd_functions = {
    "no_pipelining": forward_backward_no_pipelining,
    "no_interleaving": forward_backward_pipelining_without_interleaving,
    "interleaving": _forward_backward_pipelining_with_interleaving,
}


# note (mkozuki): `pre_process` and `post_process` are placeholders until the interleaving schedule test comes.
class MyLayer(nn.Module):
    def __init__(self, pre_process: bool, post_process: bool):
        super().__init__()
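
# A hedged usage sketch, not the actual test driver: shows how a test might
# pick a schedule out of ``fwd_bwd_functions`` and run one step with it.
# ``forward_step_func``, ``batch``, and ``model`` stand in for objects the
# real test builds (e.g. via ``build_model``), and the exact keyword
# arguments each schedule accepts may differ across apex versions.
def _example_run_schedule(name, forward_step_func, batch, model):
    fwd_bwd_func = fwd_bwd_functions[name]
    # All three schedules share the (forward_step_func, batch, model) prefix;
    # the pipelining schedules additionally take tensor-shape/dtype options.
    return fwd_bwd_func(forward_step_func, batch, model, forward_only=False)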