def test_initialize_model_parallel_with_virtual_and_split(self) -> None:
    if self.world_size < 4:
        self.skipTest("requires >= 4 GPUs")
    self.assertFalse(parallel_state.model_parallel_is_initialized())

    tensor_model_parallel_world_size = 1 + int(self.world_size > 4)
    pipeline_model_parallel_world_size = (
        self.world_size // tensor_model_parallel_world_size)
    virtual_pipeline_model_parallel_world_size = 2
    pipeline_model_parallel_split_rank = pipeline_model_parallel_world_size // 2

    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=tensor_model_parallel_world_size,
        pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
        virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_world_size,
        pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
    )
    self.assertEqual(
        calc_expected_tensor_model_paralell_rank(
            self.rank, tensor_model_parallel_world_size),
        parallel_state.get_tensor_model_parallel_rank(),
    )
    self.assertEqual(
        pipeline_model_parallel_world_size,
        parallel_state.get_pipeline_model_parallel_world_size(),
    )
    self.assertEqual(
        virtual_pipeline_model_parallel_world_size,
        parallel_state.get_virtual_pipeline_model_parallel_world_size(),
    )

    expected_pipeline_rank = (
        self.rank - (self.rank % tensor_model_parallel_world_size)
    ) % pipeline_model_parallel_world_size
    self.assertEqual(
        expected_pipeline_rank,
        parallel_state.get_pipeline_model_parallel_rank(),
    )
    # virtual pipeline model parallel rank is lazily set, i.e., right after the call of
    # `initialize_model_parallel`, it's set to 0.
    self.assertEqual(
        0,
        parallel_state.get_virtual_pipeline_model_parallel_rank(),
    )
    self.assertEqual(
        pipeline_model_parallel_split_rank,
        parallel_state.get_pipeline_model_parallel_split_rank(),
    )

    fake_split_rank = 77
    parallel_state.set_pipeline_model_parallel_split_rank(fake_split_rank)
    self.assertEqual(
        fake_split_rank,
        parallel_state.get_pipeline_model_parallel_split_rank())

    parallel_state.destroy_model_parallel()
def convert(local_rank, rank, world_size, args):
    app_state = AppState()
    app_state.data_parallel_rank = 0
    num_nodes = world_size // args.gpus_per_node
    if args.bcp:
        trainer = Trainer(devices=args.gpus_per_node,
                          num_nodes=num_nodes,
                          accelerator='gpu',
                          plugins=[TorchElasticEnvironment()])
    else:
        trainer = Trainer(devices=args.gpus_per_node,
                          num_nodes=num_nodes,
                          accelerator='gpu')

    app_state.pipeline_model_parallel_size = args.pipeline_model_parallel_size
    app_state.tensor_model_parallel_size = args.tensor_model_parallel_size
    app_state.model_parallel_size = (
        app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size)

    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=app_state.tensor_model_parallel_size,
        pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size,
    )

    app_state.pipeline_model_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
    app_state.tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()

    # inject model parallel rank
    checkpoint_path = inject_model_parallel_rank(
        os.path.join(args.checkpoint_folder, args.checkpoint_name))

    logging.info(
        f'rank: {rank}, local_rank: {local_rank}, is loading checkpoint: {checkpoint_path} '
        f'for tp_rank: {app_state.tensor_model_parallel_rank} '
        f'and pp_rank: {app_state.pipeline_model_parallel_rank}'
    )

    if args.model_type == 'gpt':
        model = MegatronGPTModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    elif args.model_type == 'bert':
        model = MegatronBertModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    elif args.model_type == 't5':
        model = MegatronT5Model.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    elif args.model_type == 'nmt':
        model = MegatronNMTModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    model._save_restore_connector = NLPSaveRestoreConnector()

    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    model.save_to(args.nemo_file_path)
    logging.info(f'NeMo model saved to: {args.nemo_file_path}')
def _set_random_seed(seed_):
    """Set random seed for reproducibility."""
    if seed_ is not None and seed_ > 0:
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + (100 * get_pipeline_model_parallel_rank())
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            tensor_parallel.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError(
            'Seed ({}) should be a positive integer.'.format(seed_))
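# Illustrative sketch (not part of the library) of how the per-stage seed offset in
# `_set_random_seed` above plays out: every pipeline stage derives its own seed from the
# base seed, while ranks within the same stage stay in sync. The base seed and stage count
# below are assumed values for illustration only.
def _example_stage_seeds(base_seed: int = 1234, pipeline_world_size: int = 4):
    # Mirrors `seed = seed_ + 100 * pipeline_rank` from `_set_random_seed`.
    return {pp_rank: base_seed + 100 * pp_rank for pp_rank in range(pipeline_world_size)}

# _example_stage_seeds() -> {0: 1234, 1: 1334, 2: 1434, 3: 1534}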
def init_model_parallel(self, global_rank: int, world_size: int) -> None:
    """ Initializes Megatron-LM model parallel if using model parallelism.

    Args:
        global_rank (int): the global process index.
        world_size (int): the total number of GPUs, num_nodes * num_devices
    """
    app_state = AppState()

    # we initialize megatron-lm model parallel and data parallel groups
    # after initializing DDP with PTL.
    if app_state.model_parallel_size is not None:
        # destroy groups in case they have already been created
        # this happens with multiple calls to trainer.test for example
        parallel_state.destroy_model_parallel()
        if torch.distributed.is_initialized():
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=app_state.tensor_model_parallel_size,
                pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size,
                pipeline_model_parallel_split_rank_=app_state.pipeline_model_parallel_split_rank,
            )

            # assert that the pre-set tp and pp ranks match after model parallel init
            assert app_state.tensor_model_parallel_rank == parallel_state.get_tensor_model_parallel_rank()
            assert app_state.pipeline_model_parallel_rank == parallel_state.get_pipeline_model_parallel_rank()

            app_state.tensor_model_parallel_group = parallel_state.get_tensor_model_parallel_group()
            app_state.data_parallel_group = parallel_state.get_data_parallel_group()
            app_state.data_parallel_rank = parallel_state.get_data_parallel_rank()
            app_state.data_parallel_size = parallel_state.get_data_parallel_world_size()
            app_state.pipeline_model_parallel_group = parallel_state.get_pipeline_model_parallel_group()
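# Hedged usage sketch (not NeMo code): `init_model_parallel` above assumes the
# torch.distributed process group already exists, since Megatron builds its
# tensor-/pipeline-/data-parallel sub-groups on top of it. The call order and the
# AppState field values below are illustrative assumptions.
#
# app_state = AppState()
# app_state.tensor_model_parallel_size = 2
# app_state.pipeline_model_parallel_size = 2
# app_state.model_parallel_size = 4
# torch.distributed.init_process_group(backend="nccl")   # e.g., done by PTL/DDP setup
# strategy.init_model_parallel(global_rank=torch.distributed.get_rank(),   # `strategy` is hypothetical
#                              world_size=torch.distributed.get_world_size())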
def init_weights(m):
    rank = parallel_state.get_pipeline_model_parallel_rank()
    if isinstance(m, torch.nn.Linear):
        m.weight.fill_((rank + offset + 1.0) / weight_coeff)
        m.bias.fill_(1.0)
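# Hedged usage sketch: `init_weights` above closes over `offset` and `weight_coeff`,
# which are defined in its enclosing (test) scope; the values below are assumptions for
# illustration. `torch.nn.Module.apply` visits every submodule, so each Linear layer gets
# a deterministic weight that encodes its pipeline rank. Filling parameters in place is
# typically done under `torch.no_grad()`.
#
# offset = 0           # assumed value
# weight_coeff = 64.0  # assumed value
# model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
# with torch.no_grad():
#     model.apply(init_weights)  # weights become (pp_rank + offset + 1) / weight_coeff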
def __init__(
    self,
    init_method,
    output_layer_init_method,
    num_layers,
    hidden_size,
    ffn_hidden_size,
    num_attention_heads,
    apply_query_key_layer_scaling=True,
    kv_channels=None,
    layer_type=LayerType.encoder,
    self_attn_mask_type=AttnMaskType.padding,
    pre_process=True,
    post_process=True,
    precision=16,
    fp32_residual_connection=False,
    activations_checkpoint_method=None,
    activations_checkpoint_num_layers=1,
    layernorm_epsilon=1e-5,
    hidden_dropout=0.1,
    use_cpu_initialization=False,
    bias_gelu_fusion=True,
    openai_gelu=False,
    onnx_safe=False,
):
    super(ParallelTransformer, self).__init__()

    if kv_channels is None:
        assert (
            hidden_size % num_attention_heads == 0
        ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
        kv_channels = hidden_size // num_attention_heads

    self.fp32_residual_connection = fp32_residual_connection
    self.pre_process = pre_process
    self.post_process = post_process
    self.input_tensor = None

    # Store activation checkpointing flag.
    self.activations_checkpoint_method = activations_checkpoint_method
    self.activations_checkpoint_num_layers = activations_checkpoint_num_layers

    # Number of layers.
    assert (
        num_layers % parallel_state.get_pipeline_model_parallel_world_size() == 0
    ), 'num_layers must be divisible by pipeline_model_parallel_size'
    self.num_layers = num_layers // parallel_state.get_pipeline_model_parallel_world_size()

    # Transformer layers.
    def build_layer(layer_number):
        return ParallelTransformerLayer(
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            layer_number=layer_number,
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            layer_type=layer_type,
            self_attn_mask_type=self_attn_mask_type,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            layernorm_epsilon=layernorm_epsilon,
            hidden_dropout=hidden_dropout,
            use_cpu_initialization=use_cpu_initialization,
            bias_gelu_fusion=bias_gelu_fusion,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
        )

    # TODO: get virtual_pipeline_model_parallel_size from apex.mpu
    # if parallel_state.get_virtual_pipeline_model_parallel_rank() is not None:
    #     assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, (
    #         'num_layers_per_stage must be divisible by virtual_pipeline_model_parallel_size'
    #     )
    #     # Number of layers in each model chunk is the number of layers in the stage,
    #     # divided by the number of model chunks in a stage.
    #     self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
    #     # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
    #     # layers to stages like (each list is a model chunk):
    #     # Stage 0: [0] [2] [4] [6]
    #     # Stage 1: [1] [3] [5] [7]
    #     # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
    #     # layers to stages like (each list is a model chunk):
    #     # Stage 0: [0, 1] [4, 5]
    #     # Stage 1: [2, 3] [6, 7]
    #     offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
    #         args.num_layers // args.virtual_pipeline_model_parallel_size
    #     ) + (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
    # else:
    #     # Each stage gets a contiguous set of layers.
    #     offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers
    offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers

    self.layers = torch.nn.ModuleList(
        [build_layer(i + 1 + offset) for i in range(self.num_layers)])

    if self.post_process:
        # Final layer norm before output.
        self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
def build_model(
    model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
    wrap_with_ddp: bool = True,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    *args,
    **kwargs
) -> List[torch.nn.Module]:
    """Build the model satisfying pipeline model parallel requirements.

    This function sets `pre_process` and `post_process` to `**kwargs` and passes
    `*args` and `**kwargs` to `model_provider_func`.

    Args:
        model_provider_func: A function which takes `*args` and `**kwargs` and returns a `nn.Module`.
        wrap_with_ddp: If :obj:`True`, wrap the instantiated model with
            `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`.
        virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel.
        *args: arguments for model provider func
        **kwargs: Keyword arguments for model provider func

    Returns:
        a list of `nn.Module`(s). If `virtual_pipeline_model_parallel_size` is not None,
        the list has multiple models, otherwise one.
    """
    if (
        parallel_state.get_pipeline_model_parallel_world_size() > 1
        and virtual_pipeline_model_parallel_size is not None
    ):
        model = []
        for i in range(virtual_pipeline_model_parallel_size):
            cur_args = args
            cur_kwargs = kwargs
            parallel_state.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
            })
            this_model = model_provider_func(*cur_args, **cur_kwargs)
            model.append(this_model)
    else:
        cur_args = args
        cur_kwargs = kwargs
        pre_process = parallel_state.is_pipeline_first_stage()
        post_process = parallel_state.is_pipeline_last_stage()
        cur_kwargs.update({
            "pre_process": pre_process,
            "post_process": post_process,
        })
        model = model_provider_func(*cur_args, **cur_kwargs)

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if parallel_state.get_data_parallel_rank() == 0:
        msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
            parallel_state.get_tensor_model_parallel_rank(),
            parallel_state.get_pipeline_model_parallel_rank(),
            sum([sum([p.nelement() for p in model_module.parameters()])
                 for model_module in model])
        )
        print(msg, flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    if wrap_with_ddp:
        i = torch.cuda.current_device()
        model = [
            torch.nn.parallel.distributed.DistributedDataParallel(
                model_module,
                device_ids=[i],
                output_device=i,
                process_group=parallel_state.get_data_parallel_group(),
            )
            for model_module in model
        ]

    return model
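# Hedged usage sketch for `build_model` above. `gpt_provider`, `MyTransformerLM`, and the
# hyperparameters are made up for illustration; the real provider is whatever model
# constructor the caller supplies. The key point is that `pre_process` and `post_process`
# are injected by `build_model`, so the provider must accept them as keyword arguments.
#
# def gpt_provider(pre_process: bool = True, post_process: bool = True) -> torch.nn.Module:
#     return MyTransformerLM(num_layers=12, hidden_size=768,
#                            pre_process=pre_process, post_process=post_process)
#
# # One module per virtual model chunk when interleaving is enabled,
# # otherwise a single-element list.
# models = build_model(gpt_provider, wrap_with_ddp=True,
#                      virtual_pipeline_model_parallel_size=2)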
def forward_backward_pipelining_without_interleaving(
    forward_step_func: FwdStepFunc,
    batch: Batch,
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
):
    """Run non-interleaved 1F1B schedule, with communication between pipeline stages.

    This pipeline parallel scheduling consists of three steps:
        1. warmup
        2. 1F1B a.k.a. steady state
        3. cooldown if not forward_only

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`'s.
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only:
        tensor_shape: Shape of tensor. Required for P2P communication.

    Returns:
        a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
    """
    # timers = get_timers()

    model = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected to be a `nn.Module`, but {type(model)}"
        raise RuntimeError(msg)
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = (
        parallel_state.get_pipeline_model_parallel_world_size()
        - parallel_state.get_pipeline_model_parallel_rank()
        - 1)
    num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining = num_microbatches - num_warmup_microbatches

    _logger.info(f"num_microbatches: {num_microbatches}, "
                 f"num_warmup_microbatches: {num_warmup_microbatches}, "
                 f"num_microbatches_remaining: {num_microbatches_remaining}")

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
    output_tensors = None
    if not forward_only:
        input_tensors = []
        output_tensors = []
    losses_reduced = []

    ###################################################################################################
    # Run warmup forward passes.
    ###################################################################################################
    _logger.info("Warmup")
    for i in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
        _logger.debug("receive fwd")
        input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
        cur_microbatch = get_kth_microbatch(batch, i)
        output_tensor = forward_step(forward_step_func, cur_microbatch, model,
                                     input_tensor, losses_reduced)
        _logger.debug("send fwd")
        p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        _logger.debug("recv_forward before steady state start")
        input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)

    ###################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################
    _logger.info("Steady phase")
    for i in range(num_microbatches_remaining):
        _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
        last_iteration = i == (num_microbatches_remaining - 1)

        cur_microbatch = get_kth_microbatch(batch, i + num_warmup_microbatches)
        output_tensor = forward_step(forward_step_func, cur_microbatch, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            _logger.debug("send fwd")
            p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)

            if not last_iteration:
                _logger.debug("receive fwd (not last iteration)")
                input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
        else:
            _logger.debug("send fwd & receive bwd")
            output_tensor_grad = p2p_communication.send_forward_recv_backward(
                output_tensor, tensor_shape=tensor_shape)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

            # Pop input_tensor and output_tensor from the start of the list for the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = backward_step(input_tensor, output_tensor,
                                              output_tensor_grad)

            if last_iteration:
                input_tensor = None
                _logger.debug("send bwd")
                p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)
            else:
                _logger.debug("send bwd and receive fwd")
                input_tensor = p2p_communication.send_backward_recv_forward(
                    input_tensor_grad, tensor_shape=tensor_shape)

    ###################################################################################################
    # Run cooldown backward passes.
    ###################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        for i in range(num_warmup_microbatches):
            _logger.debug(f"cooldown iter: {i} / {num_warmup_microbatches}")
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            _logger.debug("receive bwd")
            output_tensor_grad = p2p_communication.recv_backward(tensor_shape=tensor_shape)

            input_tensor_grad = backward_step(input_tensor, output_tensor,
                                              output_tensor_grad)

            _logger.debug("send bwd")
            p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)

    return losses_reduced
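# Illustrative sketch (not library code) of the warmup bookkeeping used by the 1F1B
# schedule above: each pipeline rank runs `pipeline_world_size - rank - 1` warmup forward
# passes (capped by the number of microbatches), then alternates one forward with one
# backward in the steady state. The microbatch and world-size values are assumptions.
def _example_1f1b_counts(num_microbatches: int = 8, pipeline_world_size: int = 4):
    counts = {}
    for rank in range(pipeline_world_size):
        warmup = min(pipeline_world_size - rank - 1, num_microbatches)
        counts[rank] = {"warmup": warmup, "steady": num_microbatches - warmup}
    return counts

# _example_1f1b_counts() ->
# {0: {'warmup': 3, 'steady': 5}, 1: {'warmup': 2, 'steady': 6},
#  2: {'warmup': 1, 'steady': 7}, 3: {'warmup': 0, 'steady': 8}}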
def convert(local_rank, rank, world_size, args):
    app_state = AppState()
    app_state.data_parallel_rank = 0
    tensor_model_parallel_size = args.tensor_model_parallel_size
    num_nodes = world_size // args.gpus_per_node
    pipeline_model_parallel_size = world_size // args.tensor_model_parallel_size
    assert args.pipeline_model_parallel_size == pipeline_model_parallel_size

    trainer = Trainer(devices=args.gpus_per_node,
                      accelerator='gpu',
                      num_nodes=num_nodes)

    app_state.pipeline_model_parallel_size = args.pipeline_model_parallel_size
    app_state.tensor_model_parallel_size = args.tensor_model_parallel_size
    app_state.model_parallel_size = (
        app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size)

    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=app_state.tensor_model_parallel_size,
        pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size,
    )

    app_state.pipeline_model_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
    app_state.tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()

    pipeline_rank = rank // tensor_model_parallel_size
    tensor_rank = app_state.tensor_model_parallel_rank
    assert pipeline_rank == app_state.pipeline_model_parallel_rank

    if tensor_model_parallel_size is not None and tensor_model_parallel_size > 1 and pipeline_model_parallel_size == 1:
        # inject model parallel rank
        checkpoint_path = os.path.join(args.checkpoint_folder,
                                       f'mp_rank_{tensor_rank:02d}',
                                       args.checkpoint_name)
    elif tensor_model_parallel_size is not None and pipeline_model_parallel_size > 1:
        checkpoint_path = os.path.join(
            args.checkpoint_folder,
            f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}',
            args.checkpoint_name)
    else:
        checkpoint_path = os.path.join(args.checkpoint_folder, args.checkpoint_name)

    logging.info(f"loading checkpoint {checkpoint_path}")

    if args.model_type == 'gpt':
        ## this dictionary is used to rename the model parameters
        name_translate = {}
        name_translate['transformer'] = 'encoder'
        name_translate['.attention.'] = '.self_attention.'
        # nemo megatron doesn't have _for_head key
        name_translate['word_embeddings_for_head'] = 'word_embeddings'
        checkpoint, consumed, steps, version = load_from_checkpoint(
            MegatronGPTModel,
            checkpoint_path,
            hparams_file=args.hparams_file,
            trainer=trainer,
            translator=name_translate,
            strict=False,
        )
    elif args.model_type == 'bert':
        ## this dictionary is used to rename the model parameters
        name_translate = {}
        name_translate['transformer'] = 'encoder'
        name_translate['.attention.'] = '.self_attention.'
        # nemo megatron doesn't have _for_head key
        name_translate['word_embeddings_for_head'] = 'word_embeddings'
        checkpoint, consumed, steps, version = load_from_checkpoint(
            MegatronBertModel,
            checkpoint_path,
            hparams_file=args.hparams_file,
            trainer=trainer,
            translator=name_translate,
            strict=False,
        )
    else:
        raise NotImplementedError("{} is not supported".format(args.model_type))

    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    if args.output_ckpt_file_path:
        filepath = args.output_ckpt_file_path
        base_dir = pathlib.Path(filepath).parent
        filename_str = pathlib.Path(filepath).name
        suffix = '.ckpt'
        content = {}
        if consumed is not None:
            content['consumed'] = consumed
        else:
            content['consumed'] = 0
        if steps is not None:
            content['steps'] = steps
        else:
            content['steps'] = 0
        filename = filename_str.format(**content) + suffix
        checkpoint_path_output = inject_model_parallel_rank(
            os.path.join(base_dir, filename))
        trainer.accelerator.training_type_plugin.checkpoint_io.save_checkpoint(
            checkpoint, checkpoint_path_output)
        logging.info(
            f'NeMo model checkpoint files saved to: {args.output_ckpt_file_path}')

    if args.nemo_file_path:
        if args.model_type == 'gpt':
            model = load_model(MegatronGPTModel, checkpoint, strict=False, trainer=trainer)
        elif args.model_type == 'bert':
            model = load_model(MegatronBertModel, checkpoint, strict=False, trainer=trainer)
        else:
            raise NotImplementedError("{} is not supported".format(args.model_type))

        # verify tensor parallel rank id and pipeline parallel rank id matches
        assert app_state.data_parallel_size == 1
        assert app_state.tensor_model_parallel_size == tensor_model_parallel_size
        assert app_state.tensor_model_parallel_rank == tensor_rank
        assert app_state.pipeline_model_parallel_size == pipeline_model_parallel_size
        assert app_state.pipeline_model_parallel_rank == pipeline_rank

        model._save_restore_connector = NLPSaveRestoreConnector()
        model.save_to(args.nemo_file_path)
        logging.info(f'NeMo model saved to: {args.nemo_file_path}')
def _forward_backward_pipelining_with_interleaving(
    forward_step_func: FwdStepFunc,
    batch: List[Optional[Batch]],
    model: List[torch.nn.Module],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
    dtype: Optional[torch.dtype] = None,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    deallocate_pipeline_outputs: bool = False,
    **kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    """Run interleaved 1F1B schedule with communication between pipeline stages as needed.

    This function assumes `batch` and `model` are a list of `Batch`'s and a list of
    `torch.nn.Module`'s, respectively. This means that model is split into model chunks.

    This pipeline parallel scheduling consists of three steps:
        1. warmup
        2. 1F1B a.k.a. steady state
        3. cooldown

    Note that if `forward_only` this scheduling consists of only warmup phase.

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`'s.
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only:
        tensor_shape: Shape of tensor.
        dtype: dtype used in p2p communication. If ``None`` (default value),
            torch.float32 will be used even if ``autocast`` is enabled.
        grad_scaler:
        disable_autocast:
        deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
            each pipeline stage. Experimental.

    Returns:
        a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
    """
    if not isinstance(model, list):
        raise RuntimeError("`model` must be a list of `nn.Module`'s")

    num_model_chunks: int = len(model)
    input_tensors: List[List[Union[None, torch.Tensor]]] = [
        [] for _ in range(num_model_chunks)
    ]
    output_tensors: List[List[Union[None, torch.Tensor]]] = [
        [] for _ in range(num_model_chunks)
    ]
    curr_iters: List[int] = [0 for _ in range(num_model_chunks)]
    losses_reduced: List[Union[None, torch.Tensor]] = []
    if not forward_only:
        output_tensor_grads: List[List[Union[None, torch.Tensor]]] = [
            [] for _ in range(num_model_chunks)
        ]

    pipeline_parallel_size: int = parallel_state.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank: int = parallel_state.get_pipeline_model_parallel_rank()

    # Compute number of warmup and remaining microbatches.
    num_microbatches: int = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches: bool = False
    if forward_only:
        num_warmup_microbatches: int = num_microbatches
    else:
        # Run all forward passes and then all backward passes if number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
        # all workers, followed by more microbatches after depending on
        # stage ID (more forward passes for earlier stages, later stages can
        # immediately start with 1F1B).
        if get_num_microbatches() == pipeline_parallel_size:
            num_warmup_microbatches = num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches

    _logger.info(f"num_microbatches: {num_microbatches}, "
                 f"num_warmup_microbatches: {num_warmup_microbatches}, "
                 f"num_microbatches_remaining: {num_microbatches_remaining}")

    ###################################################################################################
    # Helper function definitions.
    ###################################################################################################
    def get_model_chunk_id(microbatch_id: int, forward: bool) -> int:
        """Helper function to get the model chunk ID given the iteration number."""
        pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = num_model_chunks - model_chunk_id - 1
        return model_chunk_id

    def forward_step_helper(microbatch_id: int, curr_iters: List[int]) -> torch.Tensor:
        """Helper method to run forward step with model split into chunks

        (run set_virtual_pipeline_model_parallel_rank() before calling forward_step()).
        """
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # forward step
        if (parallel_state.is_pipeline_first_stage() and
                len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id])):
            input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(
            forward_step_func,
            get_kth_microbatch(batch, curr_iters[model_chunk_id]),
            model[model_chunk_id],
            input_tensor,
            losses_reduced,
            dtype,
            disable_autocast,
        )
        curr_iters[model_chunk_id] += 1
        output_tensors[model_chunk_id].append(output_tensor)

        # if forward-only, no need to save tensors for a backward pass
        if forward_only:
            input_tensors[model_chunk_id].pop()
            output_tensors[model_chunk_id].pop()

        return output_tensor

    def backward_step_helper(microbatch_id: int) -> torch.Tensor:
        """Helper method to run backward step with model split into chunks

        (run set_virtual_pipeline_model_parallel_rank() before calling backward_step()).
        """
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        model_type = get_model_type(model[model_chunk_id])
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if parallel_state.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = backward_step(
            input_tensor,
            output_tensor,
            output_tensor_grad,
            model_type=model_type,
            grad_scaler=grad_scaler,
            deallocate_pipeline_outputs=deallocate_pipeline_outputs)

        return input_tensor_grad

    ###################################################################################################
    # Run warmup forward passes.
    ###################################################################################################
    parallel_state.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype))
    _logger.info("Warmup phase")
    for k in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {k} / {num_warmup_microbatches}")
        output_tensor = forward_step_helper(k, curr_iters)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
        recv_prev = True
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (num_microbatches - 1):
            recv_prev = False
        _logger.debug(
            f"next fwd model chunk ID: {next_forward_model_chunk_id}, recv_prev: {recv_prev}")

        # Don't send tensor downstream if on last stage.
        if parallel_state.is_pipeline_last_stage():
            _logger.debug("Pipeline last stage, not sending tensor downstream")
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        if k == (num_warmup_microbatches - 1) and not forward_only and not all_warmup_microbatches:
            input_tensor_grad = None
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            _logger.debug("send fwd&bwd and receive fwd&bwd")
            (
                input_tensor,
                output_tensor_grad,
            ) = p2p_communication.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
                tensor_shape=tensor_shape,
                dtype=dtype,
            )
            output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
        else:
            _logger.debug("send fwd and receive fwd")
            input_tensor = p2p_communication.send_forward_recv_forward(
                output_tensor,
                recv_prev=recv_prev,
                tensor_shape=tensor_shape,
                dtype=dtype)
        free_output_tensor(output_tensor, deallocate_pipeline_outputs)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    ###################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################
    _logger.info("Steady phase")
    for k in range(num_microbatches_remaining):
        # Forward pass.
        _logger.debug(f" steady phase iter {k} / {num_microbatches_remaining}")
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k, curr_iters)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.

        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
        if parallel_state.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        _logger.debug(
            f"fwd/bwd model chunk id: {forward_model_chunk_id}/{backward_model_chunk_id}")
        if parallel_state.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(
                forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)

        recv_next = True
        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
            next_backward_model_chunk_id = get_model_chunk_id(
                backward_k - (pipeline_parallel_size - 1), forward=False)
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False

        # Communicate tensors.
        _logger.debug("send fwd&bwd and receive fwd&bwd")
        (
            input_tensor,
            output_tensor_grad,
        ) = p2p_communication.send_forward_backward_recv_forward_backward(
            output_tensor,
            input_tensor_grad,
            recv_prev=recv_prev,
            recv_next=recv_next,
            tensor_shape=tensor_shape,
            dtype=dtype,
        )
        free_output_tensor(output_tensor, deallocate_pipeline_outputs)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)

    ###################################################################################################
    # Run cooldown backward passes (flush out pipeline).
    ###################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks - 1].append(
                p2p_communication.recv_backward(tensor_shape=tensor_shape, dtype=dtype))
        for k in range(num_microbatches_remaining, num_microbatches):
            _logger.debug(
                f"cooldown iter {k} in range({num_microbatches_remaining}, {num_microbatches})")
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad,
                    recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    dtype=dtype))

    return losses_reduced
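# Illustrative helper (not part of the library) mirroring `get_model_chunk_id` above, to
# show how microbatches map onto model chunks in the interleaved schedule. The parameter
# values are assumptions chosen for a small example.
def _example_chunk_ids(num_microbatches=8, pipeline_parallel_size=2, num_model_chunks=2, forward=True):
    ids = []
    for microbatch_id in range(num_microbatches):
        in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        chunk = in_group // pipeline_parallel_size
        if not forward:
            chunk = num_model_chunks - chunk - 1
        ids.append(chunk)
    return ids

# _example_chunk_ids() -> [0, 0, 1, 1, 0, 0, 1, 1]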
def build_model(
    model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
    wrap_with_ddp: bool = True,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    model_type: ModelType = ModelType.encoder_or_decoder,
    *args: Any,
    **kwargs: Any,
) -> List[torch.nn.Module]:
    """Build the model satisfying pipeline model parallel requirements.

    This function sets `pre_process` and `post_process` to `**kwargs` and passes
    `*args` and `**kwargs` to `model_provider_func`.

    Args:
        model_provider_func: A function which takes `*args` and `**kwargs` and returns a `nn.Module`.
        wrap_with_ddp: If :obj:`True`, wrap the instantiated model with
            `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`.
        virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel.
        model_type:
        *args: arguments for model provider func
        **kwargs: Keyword arguments for model provider func

    Returns:
        a list of `nn.Module`(s). If `virtual_pipeline_model_parallel_size` is not None,
        the list has multiple models, otherwise one.
    """
    if (parallel_state.get_pipeline_model_parallel_world_size() > 1
            and virtual_pipeline_model_parallel_size is not None):
        model = []
        for i in range(virtual_pipeline_model_parallel_size):
            cur_args = args
            cur_kwargs = kwargs
            parallel_state.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
            })
            this_model = model_provider_func(*cur_args, **cur_kwargs)
            model.append(this_model)
    else:
        cur_args = args
        cur_kwargs = kwargs
        if model_type == ModelType.encoder_or_decoder:
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
            })
            model = model_provider_func(*cur_args, **cur_kwargs)
        elif model_type == ModelType.encoder_and_decoder:
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            # `add_encoder` & `add_decoder` logic.
            add_encoder, add_decoder = True, True
            if parallel_state.get_pipeline_model_parallel_world_size() > 1:
                split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
                if split_rank is None:
                    raise RuntimeError(
                        "Split rank needs to be specified for model with both encoder and decoder."
                    )
                rank = parallel_state.get_pipeline_model_parallel_rank()
                world_size = parallel_state.get_pipeline_model_parallel_world_size()
                pre_process = rank == 0 or rank == split_rank
                post_process = rank == (split_rank - 1) or rank == (world_size - 1)
                add_encoder = parallel_state.is_pipeline_stage_before_split()
                add_decoder = parallel_state.is_pipeline_stage_after_split()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
                "add_encoder": add_encoder,
                "add_decoder": add_decoder,
            })
            model = model_provider_func(*cur_args, **cur_kwargs)
        model.model_type = model_type

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if (parallel_state.model_parallel_is_initialized()
            and parallel_state.get_data_parallel_rank() == 0):
        msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
            parallel_state.get_tensor_model_parallel_rank(),
            parallel_state.get_pipeline_model_parallel_rank(),
            _calc_number_of_params(model),
        )
        print(msg, flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    if wrap_with_ddp:
        i = torch.cuda.current_device()
        model = [
            torch.nn.parallel.distributed.DistributedDataParallel(
                model_module,
                device_ids=[i],
                output_device=i,
                process_group=parallel_state.get_data_parallel_group(),
            )
            for model_module in model
        ]

    return model
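# Hedged sketch of the encoder/decoder split handled by `build_model` above when
# `model_type == ModelType.encoder_and_decoder`. The provider name is made up; the real
# provider must accept `add_encoder` and `add_decoder` in addition to
# `pre_process` / `post_process`.
#
# # With pipeline_model_parallel_size=4 and pipeline_model_parallel_split_rank=2:
# #   ranks 0-1 build encoder stages (add_encoder=True, add_decoder=False)
# #   ranks 2-3 build decoder stages (add_encoder=False, add_decoder=True)
# model = build_model(
#     t5_provider,                       # hypothetical encoder-decoder provider
#     wrap_with_ddp=True,
#     model_type=ModelType.encoder_and_decoder,
# )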
def forward_backward_pipelining_without_interleaving(
    forward_step_func: FwdStepFunc,
    batch: Optional[Batch],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
    decoder_sequence_length: Optional[int] = None,
    dtype: Optional[torch.dtype] = None,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    deallocate_pipeline_outputs: bool = False,
    **kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    """Run non-interleaved 1F1B schedule, with communication between pipeline stages.

    This pipeline parallel scheduling consists of three steps:
        1. warmup
        2. 1F1B a.k.a. steady state
        3. cooldown if not forward_only

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`'s.
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only:
        tensor_shape: Shape of tensor. Required for P2P communication.
        dtype: dtype used in p2p communication. If ``None`` (default value),
            torch.float32 will be used even if ``autocast`` is enabled.
        grad_scaler:
        disable_autocast:
        deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
            each pipeline stage. Experimental.

    Returns:
        a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
    """
    # timers = get_timers()

    model: List[torch.nn.Module] = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected to be a `nn.Module`, but {type(model)}"
        raise RuntimeError(msg)
    model: torch.nn.Module = model[0]

    # Compute number of warmup microbatches.
    num_microbatches: int = get_num_microbatches()
    num_warmup_microbatches: int = (
        parallel_state.get_pipeline_model_parallel_world_size()
        - parallel_state.get_pipeline_model_parallel_rank()
        - 1
    )
    num_warmup_microbatches: int = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches

    model_type = get_model_type(model)
    rank: int = parallel_state.get_pipeline_model_parallel_rank()
    recv_tensor_shapes: List[List[int]] = get_tensor_shapes(
        rank - 1, model_type, tensor_shape=tensor_shape, decoder_sequence_length=decoder_sequence_length
    )
    send_tensor_shapes: List[List[int]] = get_tensor_shapes(
        rank, model_type, tensor_shape=tensor_shape, decoder_sequence_length=decoder_sequence_length
    )

    _logger.info(
        f"num_microbatches: {num_microbatches}, "
        f"num_warmup_microbatches: {num_warmup_microbatches}, "
        f"num_microbatches_remaining: {num_microbatches_remaining}"
    )

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors: List[Union[None, torch.Tensor]] = []
    output_tensors: List[Union[None, torch.Tensor]] = []
    losses_reduced: List[Union[None, torch.Tensor]] = []

    ###################################################################################################
    # Run warmup forward passes.
    ###################################################################################################
    _logger.info("Warmup")
    for i in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
        _logger.debug("receive fwd")
        input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
        cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i)
        output_tensor = forward_step(
            forward_step_func,
            cur_microbatch,
            model,
            input_tensor,
            losses_reduced,
            dtype,
            disable_autocast,
        )
        _logger.debug("send fwd")
        send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, deallocate_pipeline_outputs)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        _logger.debug("recv_forward before steady state start")
        input_tensor: List[Union[None, torch.Tensor]] = recv_forward(
            tensor_shapes=recv_tensor_shapes, dtype=dtype)

    ###################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################
    _logger.info("Steady phase")
    for i in range(num_microbatches_remaining):
        _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
        last_iteration: bool = i == (num_microbatches_remaining - 1)

        cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i + num_warmup_microbatches)
        output_tensor: Union[torch.Tensor, Sequence[torch.Tensor]] = forward_step(
            forward_step_func,
            cur_microbatch,
            model,
            input_tensor,
            losses_reduced,
            dtype,
            disable_autocast,
        )
        if forward_only:
            _logger.debug("send fwd")
            send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)

            if not last_iteration:
                _logger.debug("receive fwd (not last iteration)")
                input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
        else:
            _logger.debug("send fwd & receive bwd")
            output_tensor_grad = send_forward_recv_backward(
                output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, deallocate_pipeline_outputs)

            # Pop input_tensor and output_tensor from the start of the list for the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = backward_step(
                input_tensor,
                output_tensor,
                output_tensor_grad,
                model_type=model_type,
                grad_scaler=grad_scaler,
                deallocate_pipeline_outputs=deallocate_pipeline_outputs,
            )

            if last_iteration:
                input_tensor = None
                _logger.debug("send bwd")
                send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
            else:
                _logger.debug("send bwd and receive fwd")
                input_tensor = send_backward_recv_forward(
                    input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)

    ###################################################################################################
    # Run cooldown backward passes.
    ###################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        for i in range(num_warmup_microbatches):
            _logger.debug(f"cooldown iter: {i} / {num_warmup_microbatches}")
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            _logger.debug("receive bwd")
            output_tensor_grad = recv_backward(tensor_shapes=send_tensor_shapes, dtype=dtype)

            input_tensor_grad = backward_step(
                input_tensor,
                output_tensor,
                output_tensor_grad,
                model_type=model_type,
                grad_scaler=grad_scaler,
                deallocate_pipeline_outputs=deallocate_pipeline_outputs,
            )

            _logger.debug("send bwd")
            send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)

    return losses_reduced
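# Hedged end-to-end sketch of driving the 1F1B schedule above. The forward step function,
# batch contents, and shapes are illustrative assumptions; the real `forward_step_func`
# must return the model output plus a loss function, as the docstring describes.
#
# def forward_step_func(microbatch, model):
#     tokens, labels = microbatch
#     output = model(tokens)
#     def loss_func(output_tensor):
#         loss = my_loss(output_tensor, labels)          # hypothetical loss helper
#         return loss, {"lm_loss": loss.detach()}
#     return output, loss_func
#
# losses = forward_backward_pipelining_without_interleaving(
#     forward_step_func,
#     batch,
#     model,
#     forward_only=False,
#     tensor_shape=[seq_length, micro_batch_size, hidden_size],
#     dtype=torch.bfloat16,
# )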