def initialize_word_embeddings(self, init_method, vocab_size, hidden_size, pipeline_model_parallel_size=1):
    if not self.share_word_embeddings:
        raise Exception('initialize_word_embeddings() was called but '
                        'share_word_embeddings is false')

    # TODO: pipeline model parallelism is not implemented in NeMo yet.
    # This function just initializes the word embeddings in the final stage
    # when we are using pipeline parallelism. If we aren't using pipeline
    # parallelism there is nothing to do.
    if pipeline_model_parallel_size == 1:
        return

    # Parameters are shared between the word embeddings layer and the
    # heads at the end of the model. In a pipelined setup with more than
    # one stage, the initial embedding layer and the head are on different
    # workers, so we do the following:
    # 1. Create a second copy of word_embeddings on the last stage, with
    #    initial parameters of 0.0.
    # 2. Do an all-reduce between the first and last stage to ensure that
    #    the two copies of word_embeddings start off with the same
    #    parameter values.
    # 3. In the training loop, do an all-reduce between the grads of
    #    the two word_embeddings layers so that every applied weight
    #    update is the same on both stages.
    if parallel_state.is_pipeline_last_stage():
        assert not parallel_state.is_pipeline_first_stage()
        self._word_embeddings_for_head_key = 'word_embeddings_for_head'
        # Set word_embeddings weights to 0 here, then copy the first
        # stage's weights using the all_reduce below.
        self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
            vocab_size, hidden_size, init_method=init_method)
        self.word_embeddings.weight.data.fill_(0)
        self.word_embeddings.weight.shared = True

    # Ensure that first and last stages have the same initial parameter
    # values.
    if torch.distributed.is_initialized():
        if parallel_state.is_pipeline_first_stage() or parallel_state.is_pipeline_last_stage():
            torch.distributed.all_reduce(
                self.word_embeddings_weight().data,
                group=parallel_state.get_embedding_group())
    else:
        print("WARNING! Distributed processes aren't initialized, so "
              "word embeddings in the last layer are not initialized. "
              "If you are just manipulating a model this is fine, but "
              "this needs to be handled manually. If you are training "
              "something is definitely wrong.")

def forward_step_helper(microbatch_id: int, curr_iters: List[int]) -> torch.Tensor:
    """Helper method to run forward step with model split into chunks
    (run set_virtual_pipeline_model_parallel_rank() before calling forward_step()).
    """
    model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
    parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

    # forward step
    if (parallel_state.is_pipeline_first_stage() and
            len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id])):
        input_tensors[model_chunk_id].append(None)
    input_tensor = input_tensors[model_chunk_id][-1]
    output_tensor = forward_step(
        forward_step_func,
        get_kth_microbatch(batch, curr_iters[model_chunk_id]),
        model[model_chunk_id],
        input_tensor,
        losses_reduced,
        dtype,
        disable_autocast,
    )
    curr_iters[model_chunk_id] += 1
    output_tensors[model_chunk_id].append(output_tensor)

    # if forward-only, no need to save tensors for a backward pass
    if forward_only:
        input_tensors[model_chunk_id].pop()
        output_tensors[model_chunk_id].pop()

    return output_tensor

def recv_forward(
    tensor_shape: Shape,
    override_scatter_gather_tensors_in_pipeline: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
    async_comm: bool = False,
) -> Union[torch.Tensor, FutureTensor, None]:
    """Receive tensor from previous rank in pipeline (forward receive)."""
    if parallel_state.is_pipeline_first_stage():
        return None
    # if timers is not None:
    #     timers("forward-recv").start()
    input_tensor, _ = _communicate(
        tensor_send_next=None,
        tensor_send_prev=None,
        recv_prev=True,
        recv_next=False,
        tensor_shape=tensor_shape,
        override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
        dtype_=dtype,
        async_comm=async_comm,
    )
    # if timers is not None:
    #     timers("forward-recv").stop()
    return input_tensor

def send_backward_recv_forward(
    input_tensor_grad: torch.Tensor,
    tensor_shape: Shape,
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
    async_comm: bool = False,
) -> Union[torch.Tensor, FutureTensor, None]:
    """Batched send and recv with previous rank in pipeline."""
    if parallel_state.is_pipeline_first_stage():
        return None
    # if timers is not None:
    #     timers("backward-send-forward-recv").start()
    input_tensor, _ = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=True,
        recv_next=False,
        tensor_shape=tensor_shape,
        dtype_=dtype,
        async_comm=async_comm,
    )
    # if timers is not None:
    #     timers("backward-send-forward-recv").stop()
    return input_tensor

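# A minimal, hypothetical sketch (not part of this codebase) of how the p2p
# primitives above pair up on an intermediate pipeline stage. `run_forward` and
# `run_backward` are placeholder compute functions, and `shape`/`dtype` describe
# the micro-batch activation tensor; the real schedules live in the
# forward-backward pipelining functions further below.
def _intermediate_stage_step_sketch(shape, dtype):
    # forward direction: activation arrives from the previous stage
    input_tensor = recv_forward(tensor_shape=shape, dtype=dtype)
    output_tensor = run_forward(input_tensor)                  # placeholder compute
    send_forward(output_tensor, tensor_shape=shape, dtype=dtype)

    # backward direction: gradient arrives from the next stage
    output_tensor_grad = recv_backward(tensor_shape=shape, dtype=dtype)
    input_tensor_grad = run_backward(output_tensor, output_tensor_grad)  # placeholder compute
    # batched send/recv with the previous stage: ship the gradient upstream and
    # pick up the activation for the next microbatch in a single call
    # (returns None on the pipeline-first stage, as in send_backward_recv_forward above)
    next_input = send_backward_recv_forward(input_tensor_grad, tensor_shape=shape, dtype=dtype)
    return next_input
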
def forward(self, *inputs, **kwargs):
    if parallel_state.is_pipeline_first_stage():
        inputs = fp32_to_float16(inputs, self.float16_converter)
    outputs = self.module(*inputs, **kwargs)
    if parallel_state.is_pipeline_last_stage():
        outputs = float16_to_fp32(outputs)
    return outputs

def test_send_backward_recv_backward(self):
    self._init_model_parallel()
    tensor = self.create_tensor(self.rank)

    next_tensor = None
    if parallel_state.is_pipeline_first_stage():
        next_tensor = p2p_communication.recv_backward(
            tensor_shape=self.shape, dtype=self.dtype)
    elif parallel_state.is_pipeline_last_stage():
        p2p_communication.send_backward(
            input_tensor_grad=tensor, tensor_shape=self.shape, dtype=self.dtype)
    else:
        next_tensor = p2p_communication.send_backward_recv_backward(
            input_tensor_grad=tensor,
            recv_next=True,
            tensor_shape=self.shape,
            dtype=self.dtype,
        )

    if parallel_state.is_pipeline_last_stage():
        self.assertIsNone(next_tensor)
    else:
        expected_next_tensor = self.create_tensor(self.rank + 1)
        self.assertEqual(next_tensor, expected_next_tensor)

def initialize_word_embeddings(self, init_method, vocab_size, hidden_size):
    if not self.share_word_embeddings:
        raise Exception('initialize_word_embeddings() was called but '
                        'share_word_embeddings is false')

    # This function just initializes the word embeddings in the final stage
    # when we are using pipeline parallelism. If we aren't using pipeline
    # parallelism there is nothing to do.
    if parallel_state.get_pipeline_model_parallel_world_size() == 1:
        return

    # Parameters are shared between the word embeddings layer and the
    # heads at the end of the model. In a pipelined setup with more than
    # one stage, the initial embedding layer and the head are on different
    # workers, so we do the following:
    # 1. Create a second copy of word_embeddings on the last stage, with
    #    initial parameters of 0.0.
    # 2. Do an all-reduce between the first and last stage to ensure that
    #    the two copies of word_embeddings start off with the same
    #    parameter values.
    # 3. In the training loop, do an all-reduce between the grads of
    #    the two word_embeddings layers so that every applied weight
    #    update is the same on both stages.
    if parallel_state.is_pipeline_last_stage():
        assert not parallel_state.is_pipeline_first_stage()
        self._word_embeddings_for_head_key = 'word_embeddings_for_head'
        # Set word_embeddings weights to 0 here, then copy the first
        # stage's weights using the all_reduce below.
        self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
            vocab_size, hidden_size, init_method=init_method)
        self.word_embeddings.weight.data.fill_(0)
        self.word_embeddings.weight.shared = True

def word_embeddings_weight(self):
    if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
        return self.language_model.embedding.word_embeddings.weight
    if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
        if not self.share_word_embeddings:
            raise Exception('word_embeddings_weight() called for last '
                            'stage, but share_word_embeddings is false')
        return self.word_embeddings.weight
    raise Exception('word_embeddings_weight() should be '
                    'called for first and last stage only')

def test_no_interleaving_warmup(self):
    self.assertEqual(self.world_size, 2)
    self._init_model_parallel()

    input_tensor = None
    if parallel_state.is_pipeline_first_stage():
        tensor = self.create_tensor(self.rank)
        print(tensor)
        p2p_communication.send_forward(
            output_tensor=tensor, tensor_shape=self.shape, dtype=self.dtype)
    else:
        input_tensor = p2p_communication.recv_forward(
            tensor_shape=self.shape, dtype=self.dtype)

    if parallel_state.is_pipeline_first_stage():
        self.assertIsNone(input_tensor)
    else:
        expected_input_tensor = self.create_tensor(self.rank - 1)
        self.assertEqual(input_tensor, expected_input_tensor)

def test_send_forward_recv_forward(self):
    self._init_model_parallel()
    prev_tensor = None
    tensor = self.create_tensor(self.rank)

    if parallel_state.is_pipeline_first_stage():
        p2p_communication.send_forward(
            output_tensor=tensor, tensor_shape=self.shape, dtype=self.dtype)
    elif parallel_state.is_pipeline_last_stage():
        prev_tensor = p2p_communication.recv_forward(
            tensor_shape=self.shape, dtype=self.dtype)
    else:
        prev_tensor = p2p_communication.send_forward_recv_forward(
            output_tensor=tensor,
            recv_prev=True,
            tensor_shape=self.shape,
            dtype=self.dtype,
        )

    if parallel_state.is_pipeline_first_stage():
        self.assertIsNone(prev_tensor)
    else:
        expected_prev_tensor = self.create_tensor(self.rank - 1)
        self.assertEqual(prev_tensor, expected_prev_tensor)

def sync_initial_word_embeddings(self):
    if torch.distributed.is_initialized():
        if parallel_state.is_pipeline_first_stage() or parallel_state.is_pipeline_last_stage():
            torch.distributed.all_reduce(
                self.word_embeddings_weight().data,
                group=parallel_state.get_embedding_group())
    else:
        logging.warning(
            "WARNING! Distributed processes aren't initialized, so "
            "word embeddings in the last layer are not synchronized. "
            "If you are just manipulating a model this is fine, but "
            "this needs to be handled manually. If you are training "
            "something is definitely wrong.")

def allreduce_first_last_embeddings(self):
    # Modified from megatron-lm: https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/training.py#L407
    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    if parallel_state.get_pipeline_model_parallel_world_size() > 1 and (
            parallel_state.is_pipeline_first_stage() or parallel_state.is_pipeline_last_stage()):
        if self.model.share_word_embeddings:
            word_embeddings_weight = self.model.word_embeddings_weight()
            if self.megatron_amp_o2:
                # O2 recipe stores a "main" copy of weights and grads
                grad = word_embeddings_weight.main_grad
            else:
                grad = word_embeddings_weight.grad
            torch.distributed.all_reduce(
                grad, group=parallel_state.get_embedding_group())

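# A hypothetical sketch (names such as `model`, `trainer`, `optimizer`, and
# `dataloader` are placeholders, not this codebase's API) of how the three
# embedding-sharing pieces above cooperate: the duplicate head embedding is
# created once on the last stage, synchronized once with the first stage, and
# its gradient is re-synced after every backward pass so both copies receive
# identical weight updates.
model.initialize_word_embeddings(init_method, vocab_size, hidden_size)
model.sync_initial_word_embeddings()            # one-time all-reduce of initial weights
for batch in dataloader:                        # placeholder training loop
    loss = model(batch)
    loss.backward()
    trainer.allreduce_first_last_embeddings()   # grad all-reduce before the optimizer step
    optimizer.step()
    optimizer.zero_grad()
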
def send_backward(
    input_tensor_grad: torch.Tensor,
    tensor_shape: Shape,
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
) -> None:
    """Send tensor to previous rank in pipeline (backward send)."""
    if parallel_state.is_pipeline_first_stage():
        return
    # if timers is not None:
    #     timers("backward-send").start()
    _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=False,
        tensor_shape=tensor_shape,
        dtype_=dtype,
    )

def send_backward_recv_forward(
    input_tensor_grad: torch.Tensor,
    tensor_shape: Shape,
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
) -> torch.Tensor:
    """Batched send and recv with previous rank in pipeline."""
    if parallel_state.is_pipeline_first_stage():
        return None
    if timers is not None:
        timers("backward-send-forward-recv").start()
    input_tensor, _ = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=True,
        recv_next=False,
        tensor_shape=tensor_shape,
        dtype_=_get_current_dtype(dtype),
    )
    if timers is not None:
        timers("backward-send-forward-recv").stop()
    return input_tensor

def synced_generate(
    model,
    context_tokens_tensor,
    context_length_tensor,
    task_ids,
    tokens_to_generate,
    all_probs,
    temperature,
    top_k=0,
    top_p=0.0,
    greedy=False,
    repetition_penalty=1.2,
    min_tokens_to_generate=0,
):
    context_length = context_length_tensor.min().item()
    tokenizer = model.tokenizer
    tokens, attention_mask, position_ids = get_batch(model, tokenizer, context_tokens_tensor)

    if isinstance(tokenizer, TabularTokenizer):
        batch_token_iterator = tab_sample_sequence_batch(
            model,
            context_tokens_tensor,
            context_length_tensor,
            attention_mask,
            position_ids,
            tokens_to_generate,
            all_probs,
            temperature=temperature,
        )
    else:
        batch_token_iterator = sample_sequence_batch(
            model,
            context_tokens_tensor,
            context_length_tensor,
            task_ids,
            attention_mask,
            position_ids,
            tokens_to_generate,
            all_probs,
            temperature=temperature,
            extra={
                "top_p": top_p,
                "top_k": top_k,
                "greedy": greedy,
                "repetition_penalty": repetition_penalty,
                "min_tokens_to_generate": min_tokens_to_generate,
            },
        )

    for tokens, lengths, output_logits, full_logits in batch_token_iterator:
        context_length += 1

    if parallel_state.is_pipeline_last_stage():
        src = parallel_state.get_pipeline_model_parallel_last_rank()
        group = parallel_state.get_embedding_group()
        torch.distributed.broadcast(output_logits, src, group)
        if all_probs:
            src = parallel_state.get_pipeline_model_parallel_last_rank()
            group = parallel_state.get_embedding_group()
            torch.distributed.broadcast(full_logits, src, group)
    else:
        if parallel_state.is_pipeline_first_stage():
            src = parallel_state.get_pipeline_model_parallel_last_rank()
            group = parallel_state.get_embedding_group()
            output_logits = torch.empty(
                tokens.size(0), context_length - 1,
                dtype=torch.float32, device=torch.device("cuda"))
            torch.distributed.broadcast(output_logits, src, group)

            if all_probs:
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_embedding_group()
                full_logits = torch.empty(
                    tokens.size(0),
                    context_length - 1,
                    model.padded_vocab_size,
                    dtype=torch.float32,
                    device=torch.device("cuda"),
                )
                torch.distributed.broadcast(full_logits, src, group)

    if tokens is not None:
        return tokens[:, :context_length], output_logits, full_logits

def build_model(
    model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
    wrap_with_ddp: bool = True,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    *args,
    **kwargs
) -> List[torch.nn.Module]:
    """Build the model satisfying pipeline model parallel requirements.

    This function sets `pre_process` and `post_process` in `**kwargs` and passes
    `*args` and `**kwargs` to `model_provider_func`.

    Args:
        model_provider_func: A function which takes `*args` and `**kwargs` and returns a `nn.Module`.
        wrap_with_ddp: If :obj:`True`, wrap the instantiated model with
            `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`.
        virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel.
        *args: arguments for model provider func
        **kwargs: Keyword arguments for model provider func

    Returns:
        a list of `nn.Module`(s). If `virtual_pipeline_model_parallel_size` is not None,
        the list has multiple models, otherwise one.
    """
    if (
        parallel_state.get_pipeline_model_parallel_world_size() > 1
        and virtual_pipeline_model_parallel_size is not None
    ):
        model = []
        for i in range(virtual_pipeline_model_parallel_size):
            cur_args = args
            cur_kwargs = kwargs
            parallel_state.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
            })
            this_model = model_provider_func(*cur_args, **cur_kwargs)
            model.append(this_model)
    else:
        cur_args = args
        cur_kwargs = kwargs
        pre_process = parallel_state.is_pipeline_first_stage()
        post_process = parallel_state.is_pipeline_last_stage()
        cur_kwargs.update({
            "pre_process": pre_process,
            "post_process": post_process,
        })
        model = model_provider_func(*cur_args, **cur_kwargs)

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if parallel_state.get_data_parallel_rank() == 0:
        msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
            parallel_state.get_tensor_model_parallel_rank(),
            parallel_state.get_pipeline_model_parallel_rank(),
            sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model]),
        )
        print(msg, flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    if wrap_with_ddp:
        i = torch.cuda.current_device()
        model = [
            torch.nn.parallel.distributed.DistributedDataParallel(
                model_module,
                device_ids=[i],
                output_device=i,
                process_group=parallel_state.get_data_parallel_group(),
            )
            for model_module in model
        ]

    return model

def build_model(
    model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
    wrap_with_ddp: bool = True,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    model_type: ModelType = ModelType.encoder_or_decoder,
    *args: Any,
    **kwargs: Any,
) -> List[torch.nn.Module]:
    """Build the model satisfying pipeline model parallel requirements.

    This function sets `pre_process` and `post_process` in `**kwargs` and passes
    `*args` and `**kwargs` to `model_provider_func`.

    Args:
        model_provider_func: A function which takes `*args` and `**kwargs` and returns a `nn.Module`.
        wrap_with_ddp: If :obj:`True`, wrap the instantiated model with
            `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`.
        virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel.
        model_type:
        *args: arguments for model provider func
        **kwargs: Keyword arguments for model provider func

    Returns:
        a list of `nn.Module`(s). If `virtual_pipeline_model_parallel_size` is not None,
        the list has multiple models, otherwise one.
    """
    if (parallel_state.get_pipeline_model_parallel_world_size() > 1
            and virtual_pipeline_model_parallel_size is not None):
        model = []
        for i in range(virtual_pipeline_model_parallel_size):
            cur_args = args
            cur_kwargs = kwargs
            parallel_state.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
            })
            this_model = model_provider_func(*cur_args, **cur_kwargs)
            model.append(this_model)
    else:
        cur_args = args
        cur_kwargs = kwargs
        if model_type == ModelType.encoder_or_decoder:
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
            })
            model = model_provider_func(*cur_args, **cur_kwargs)
        elif model_type == ModelType.encoder_and_decoder:
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            # `add_encoder` & `add_decoder` logic.
            add_encoder, add_decoder = True, True
            if parallel_state.get_pipeline_model_parallel_world_size() > 1:
                split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
                if split_rank is None:
                    raise RuntimeError(
                        "Split rank needs to be specified for model with both encoder and decoder.")
                rank = parallel_state.get_pipeline_model_parallel_rank()
                world_size = parallel_state.get_pipeline_model_parallel_world_size()
                pre_process = rank == 0 or rank == split_rank
                post_process = rank == (split_rank - 1) or rank == (world_size - 1)
                add_encoder = parallel_state.is_pipeline_stage_before_split()
                add_decoder = parallel_state.is_pipeline_stage_after_split()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
                "add_encoder": add_encoder,
                "add_decoder": add_decoder,
            })
            model = model_provider_func(*cur_args, **cur_kwargs)
        model.model_type = model_type

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if (parallel_state.model_parallel_is_initialized()
            and parallel_state.get_data_parallel_rank() == 0):
        msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
            parallel_state.get_tensor_model_parallel_rank(),
            parallel_state.get_pipeline_model_parallel_rank(),
            _calc_number_of_params(model),
        )
        print(msg, flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    if wrap_with_ddp:
        i = torch.cuda.current_device()
        model = [
            torch.nn.parallel.distributed.DistributedDataParallel(
                model_module,
                device_ids=[i],
                output_device=i,
                process_group=parallel_state.get_data_parallel_group(),
            )
            for model_module in model
        ]

    return model

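# A hypothetical usage sketch of build_model(). It assumes
# parallel_state.initialize_model_parallel() has already been called and that
# `MyTransformerStage` is a placeholder nn.Module accepting the pre_process /
# post_process flags that build_model() injects.
def gpt_provider(pre_process=True, post_process=True):
    return MyTransformerStage(pre_process=pre_process, post_process=post_process)

models = build_model(
    gpt_provider,
    wrap_with_ddp=True,
    virtual_pipeline_model_parallel_size=None,  # e.g. 2 to get one chunk list per virtual rank
)
# Without interleaving the returned list holds a single (possibly DDP-wrapped) module.
assert len(models) == 1
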
def _forward_backward_pipelining_with_interleaving(
    forward_step_func: FwdStepFunc,
    batch: List[Optional[Batch]],
    model: List[torch.nn.Module],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
    dtype: Optional[torch.dtype] = None,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    deallocate_pipeline_outputs: bool = False,
    **kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    """Run interleaved 1F1B schedule with communication between pipeline stages as needed.

    This function assumes `batch` and `model` are a list of `Batch`'s and a list of
    `torch.nn.Module`'s, respectively. This means that the model is split into model chunks.

    This pipeline parallel scheduling consists of three steps:
    1. warmup
    2. 1F1B a.k.a. steady state
    3. cooldown
    Note that if `forward_only` this scheduling consists of only the warmup phase.

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns the model's forward output and the loss function. The loss function is
            supposed to take one `torch.Tensor` and return a `torch.Tensor` of loss and a
            dictionary of `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`'s.
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only:
        tensor_shape: Shape of tensor.
        dtype: dtype used in p2p communication. If ``None`` (default value),
            torch.float32 will be used even if ``autocast`` is enabled.
        grad_scaler:
        disable_autocast:
        deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
            each pipeline stage. Experimental.

    Returns:
        a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
    """
    if not isinstance(model, list):
        raise RuntimeError("`model` must be a list of `nn.Module`'s'")

    num_model_chunks: int = len(model)
    input_tensors: List[List[Union[None, torch.Tensor]]] = [[] for _ in range(num_model_chunks)]
    output_tensors: List[List[Union[None, torch.Tensor]]] = [[] for _ in range(num_model_chunks)]
    curr_iters: List[int] = [0 for _ in range(num_model_chunks)]
    losses_reduced: List[Union[None, torch.Tensor]] = []
    if not forward_only:
        output_tensor_grads: List[List[Union[None, torch.Tensor]]] = [[] for _ in range(num_model_chunks)]

    pipeline_parallel_size: int = parallel_state.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank: int = parallel_state.get_pipeline_model_parallel_rank()

    # Compute number of warmup and remaining microbatches.
    num_microbatches: int = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches: bool = False
    if forward_only:
        num_warmup_microbatches: int = num_microbatches
    else:
        # Run all forward passes and then all backward passes if the number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size microbatches on
        # all workers, followed by additional microbatches depending on
        # stage ID (more forward passes for earlier stages; later stages can
        # immediately start with 1F1B).
        if get_num_microbatches() == pipeline_parallel_size:
            num_warmup_microbatches = num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches

    _logger.info(f"num_microbatches: {num_microbatches}, "
                 f"num_warmup_microbatches: {num_warmup_microbatches}, "
                 f"num_microbatches_remaining: {num_microbatches_remaining}")

    ###################################################################################################################
    # Helper function definitions.
    ###################################################################################################################
    def get_model_chunk_id(microbatch_id: int, forward: bool) -> int:
        """Helper function to get the model chunk ID given the iteration number."""
        pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = num_model_chunks - model_chunk_id - 1
        return model_chunk_id

    def forward_step_helper(microbatch_id: int, curr_iters: List[int]) -> torch.Tensor:
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling forward_step()).
        """
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # forward step
        if (parallel_state.is_pipeline_first_stage() and
                len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id])):
            input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(
            forward_step_func,
            get_kth_microbatch(batch, curr_iters[model_chunk_id]),
            model[model_chunk_id],
            input_tensor,
            losses_reduced,
            dtype,
            disable_autocast,
        )
        curr_iters[model_chunk_id] += 1
        output_tensors[model_chunk_id].append(output_tensor)

        # if forward-only, no need to save tensors for a backward pass
        if forward_only:
            input_tensors[model_chunk_id].pop()
            output_tensors[model_chunk_id].pop()

        return output_tensor

    def backward_step_helper(microbatch_id: int) -> torch.Tensor:
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling backward_step()).
        """
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        model_type = get_model_type(model[model_chunk_id])
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if parallel_state.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = backward_step(
            input_tensor,
            output_tensor,
            output_tensor_grad,
            model_type=model_type,
            grad_scaler=grad_scaler,
            deallocate_pipeline_outputs=deallocate_pipeline_outputs,
        )
        return input_tensor_grad

    ###################################################################################################################
    # Run warmup forward passes.
    ###################################################################################################################
    parallel_state.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype))
    _logger.info("Warmup phase")
    for k in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {k} / {num_warmup_microbatches}")
        output_tensor = forward_step_helper(k, curr_iters)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
        recv_prev = True
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (num_microbatches - 1):
            recv_prev = False
        _logger.debug(
            f"next fwd model chunk ID: {next_forward_model_chunk_id}, recv_prev: {recv_prev}")

        # Don't send tensor downstream if on last stage.
        if parallel_state.is_pipeline_last_stage():
            _logger.debug("Pipeline last stage, not sending tensor downstream")
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        if k == (num_warmup_microbatches - 1) and not forward_only and not all_warmup_microbatches:
            input_tensor_grad = None
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            _logger.debug("send fwd&bwd and receive fwd&bwd")
            (
                input_tensor,
                output_tensor_grad,
            ) = p2p_communication.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
                tensor_shape=tensor_shape,
                dtype=dtype,
            )
            output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
        else:
            _logger.debug("send fwd and receive fwd")
            input_tensor = p2p_communication.send_forward_recv_forward(
                output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, dtype=dtype)
        free_output_tensor(output_tensor, deallocate_pipeline_outputs)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    ###################################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################################
    _logger.info("Steady phase")
    for k in range(num_microbatches_remaining):
        # Forward pass.
        _logger.debug(f" steady phase iter {k} / {num_microbatches_remaining}")
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k, curr_iters)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.

        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
        if parallel_state.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        _logger.debug(
            f"fwd/bwd model chunk id: {forward_model_chunk_id}/{backward_model_chunk_id}")
        if parallel_state.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(
                forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)

        recv_next = True
        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
            next_backward_model_chunk_id = get_model_chunk_id(
                backward_k - (pipeline_parallel_size - 1), forward=False)
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False

        # Communicate tensors.
        _logger.debug("send fwd&bwd and receive fwd&bwd")
        (
            input_tensor,
            output_tensor_grad,
        ) = p2p_communication.send_forward_backward_recv_forward_backward(
            output_tensor,
            input_tensor_grad,
            recv_prev=recv_prev,
            recv_next=recv_next,
            tensor_shape=tensor_shape,
            dtype=dtype,
        )
        free_output_tensor(output_tensor, deallocate_pipeline_outputs)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)

    ###################################################################################################################
    # Run cooldown backward passes (flush out pipeline).
    ###################################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks - 1].append(
                p2p_communication.recv_backward(tensor_shape=tensor_shape, dtype=dtype))
        for k in range(num_microbatches_remaining, num_microbatches):
            _logger.debug(
                f"cooldown iter {k} in range({num_microbatches_remaining}, {num_microbatches})")
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad,
                    recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    dtype=dtype))

    return losses_reduced

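# Worked example (illustrative only) of the warmup-count arithmetic in the
# schedule above, for a hypothetical configuration; the formulas are the ones
# used in the non-all-warmup branch.
pipeline_parallel_size = 4          # hypothetical pipeline world size
pipeline_parallel_rank = 1          # hypothetical pipeline rank
num_model_chunks = 2                # virtual pipeline chunks per rank
num_microbatches = 8 * num_model_chunks                       # get_num_microbatches() == 8
num_warmup = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
num_warmup += (num_model_chunks - 1) * pipeline_parallel_size
num_warmup = min(num_warmup, num_microbatches)                # -> 8
num_remaining = num_microbatches - num_warmup                 # -> 8
# Earlier ranks get a larger warmup count, so they run more forward passes
# before entering the 1F1B steady state.
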
def _forward_backward_test_impl(
    self,
    forward_only: bool,
    fwd_bwd_func: FwdStepFunc,
    pipeline_model_parallel_world_size: Optional[int],
    virtual_pipeline_model_parallel_size: Optional[int],
    async_comm: bool = False,
    *,
    default_backend: Optional[str] = None,
    p2p_backend: Optional[str] = None,
) -> None:
    if fwd_bwd_func == _forward_backward_pipelining_with_interleaving:
        self.assertIsNotNone(virtual_pipeline_model_parallel_size)
        self.assertGreater(virtual_pipeline_model_parallel_size, 1)
    dtype_options = self.dtypes or [torch.float32, torch.double] + _get_autocast_dtypes()

    for dtype, deallocate_pipeline_outputs in itertools.product(
            dtype_options, self.deallocate_options,
    ):
        grad_scaler = (
            torch.cuda.amp.GradScaler(init_scale=4.0) if dtype == torch.half else None
        )

        (tensor_model_parallel_world_size,
         data_parallel_size,
         pipeline_model_parallel_world_size) = _get_default_world_sizes_model_parallel_world_size(
             pipeline_model_parallel_world_size)

        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size_=tensor_model_parallel_world_size,
            pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
            virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size,
            default_backend=default_backend,
            p2p_backend=p2p_backend,
        )
        pp_utils._reconfigure_microbatch_calculator(
            rank=parallel_state.get_tensor_model_parallel_rank(),
            rampup_batch_size=None,
            global_batch_size=self.GLOBAL_BATCH_SIZE,
            micro_batch_size=self.MICRO_BATCH_SIZE,
            data_parallel_size=parallel_state.get_data_parallel_world_size(),
        )

        global_batch_shape = (
            self.GLOBAL_BATCH_SIZE // parallel_state.get_data_parallel_world_size(),
            self.HIDDEN_SIZE,
            self.HIDDEN_SIZE,
        )

        batch = None
        if parallel_state.is_pipeline_first_stage():
            batch = (torch.ones(global_batch_shape, dtype=dtype).cuda(),)

        model = build_model(
            testing_utils.model_provider_func,
            # Wrap with DDP only when data parallelism is actually used.
            wrap_with_ddp=data_parallel_size > 1,
            virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
            hidden_size=self.HIDDEN_SIZE,
        )

        offset = pipeline_model_parallel_world_size if virtual_pipeline_model_parallel_size is not None else 0
        for idx, model_module in enumerate(model):
            model_module = model_module.to(dtype)
            model_module.apply(get_init_weights_func(idx * offset))

        _param_groups = _get_params_for_weight_decay_optimization(model)
        optimizer = torch.optim.Adam(_param_groups, lr=1e-3)

        pp_utils.update_num_microbatches(0)

        loss = fwd_bwd_func(
            testing_utils.fwd_step_func,
            batch,
            model,
            forward_only=forward_only,
            # `tensor_shape` is the shape of micro batch.
            tensor_shape=(
                self.MICRO_BATCH_SIZE,
                self.HIDDEN_SIZE,
                self.HIDDEN_SIZE,
            ),
            dtype=dtype,
            async_comm=async_comm,
            grad_scaler=grad_scaler,
            deallocate_pipeline_outputs=deallocate_pipeline_outputs,
        )

        if dtype == torch.double:
            hidden_size = self.HIDDEN_SIZE
            microbatch_size = self.MICRO_BATCH_SIZE
            total_layers = pipeline_model_parallel_world_size
            if virtual_pipeline_model_parallel_size is not None:
                total_layers *= virtual_pipeline_model_parallel_size
            target_loss, target_model = get_target_loss_and_model(
                global_batch_shape, hidden_size, total_layers)

            for loss_item in loss:
                x = loss_item['avg']
                torch.testing.assert_close(x.item() / microbatch_size, target_loss.item())

            if not forward_only:
                for vm_id, model_module in enumerate(model):
                    params = list(model_module.parameters())
                    rank = params[0].get_device()
                    offset = pipeline_model_parallel_world_size
                    param_id = rank // data_parallel_size + vm_id * offset
                    target_params = target_model[param_id]

                    torch.testing.assert_close(params[0].cpu(), target_params[0])
                    torch.testing.assert_close(params[1].cpu(), target_params[1])
                    torch.testing.assert_close(
                        params[0].grad.cpu() / microbatch_size, target_params[0].grad)
                    torch.testing.assert_close(
                        params[1].grad.cpu() / microbatch_size, target_params[1].grad)

        if not forward_only:
            for m in model:
                for p in m.parameters():
                    self.assertIsNotNone(p.grad)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

        parallel_state.destroy_model_parallel()

def tab_sample_sequence_batch(
    model,
    context_tokens,
    context_lengths,
    attention_mask,
    position_ids,
    tokens_to_generate,
    all_probs=True,
    type_ids=None,
    temperature=None,
):
    app_state = AppState()
    micro_batch_size = context_tokens.shape[0]
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=micro_batch_size,
        micro_batch_size=micro_batch_size,
        data_parallel_size=1,
    )
    tokenizer = model.tokenizer
    sizes = tokenizer.code_column.sizes
    tokens_per_row = sum(sizes) + 1
    columns = tokenizer.code_column.columns
    num_columns = len(columns)
    tokenid_range = []
    for i in range(num_columns):
        tokenid_range.extend(tokenizer.code_column.get_range(i))

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        context = context_tokens[:, :context_length]
        # the context may start in the middle of the row,
        # calculate the offset according to the position of '\n' or '<|endoftext|>'
        positions = torch.where(context == tokenizer.eor)[1]
        if len(positions) == 0:
            positions = torch.where(context == tokenizer.eod)[1]
        if len(positions) != 0:
            max_position = positions.max().item()
            # TODO, need to make sure context of different batch have the same offset lengths
            # otherwise, need to calculate offset per batch_id
            offset = (context_length - max_position - 1) % tokens_per_row
        else:
            offset = 0

        eod_id = tokenizer.eos_id
        counter = 0
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        output_logits = None

        # Generate enough tokens for the longest sequence
        maxlen = tokens_to_generate + context_lengths.max().item()
        if maxlen > model.cfg.encoder_seq_length:
            maxlen = model.cfg.encoder_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length < maxlen:
            # types2use = None
            if counter == 0:
                # Allocate memory for the entire context.
                set_inference_key_value_memory = True
                tokens2use = tokens[:, :context_length]
                positions2use = position_ids[:, :context_length]
                # not using type2use. uncomment it if it is used
                # if type_ids is not None:
                #     types2use = type_ids[:, :context_length]
            else:
                # Set this to false so the memory is not reallocated.
                set_inference_key_value_memory = False
                tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
                positions2use = position_ids[:, context_length - 1].view(batch_size, -1)
                # not using type2use. uncomment it if it is used
                # if type_ids is not None:
                #     types2use = type_ids[:, context_length - 1].view(batch_size, -1)

            # micro_batch_size = 2
            attention_mask_repeat = torch.concat(
                [attention_mask for _ in range(micro_batch_size)])
            setkey_value_array = torch.tensor(
                [set_inference_key_value_memory] * micro_batch_size,
                device=torch.cuda.current_device())
            len_array = torch.tensor([maxlen] * micro_batch_size,
                                     device=torch.cuda.current_device())

            batch = [tokens2use, attention_mask_repeat, positions2use, setkey_value_array, len_array]
            tensor_shape = [tokens2use.shape[1], micro_batch_size, model.cfg.hidden_size]

            output = forward_step(model, batch, tensor_shape)

            if parallel_state.is_pipeline_last_stage():
                output = output[0]['logits'].float()
                output = tensor_parallel.gather_from_tensor_model_parallel_region(output)
                assert output is not None
                output = output.float()
                logits = output[:, -1].view(batch_size, -1).contiguous()
                token_in_row = (counter + offset) % tokens_per_row
                logits = logits.float()
                logits /= temperature
                if token_in_row == tokens_per_row - 1:
                    # line break
                    eor_id = tokenizer.eor
                    eod_id = tokenizer.eos_id
                    min_id = min(eor_id, eod_id)
                    max_id = max(eor_id, eod_id) + 1
                    logits = tab_logits(logits, min_id, max_id)
                else:
                    # limit the range
                    min_id, max_id = tokenid_range[token_in_row]
                    logits = tab_logits(logits, min_id, max_id)
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1).view(-1)
                started = context_lengths <= context_length

                # Clamp the out of vocabulary tokens.
                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)
                new_tokens = switch(tokens[:, context_length].view(-1), prev, started)
                tokens[:, context_length] = new_tokens

                if output_logits is None:
                    output_context = F.log_softmax(output[:, :context_length, :], 2)
                    indices = torch.unsqueeze(tokens[:, 1:context_length + 1], 2)
                    output_logits = torch.gather(output_context, 2, indices).squeeze(2)
                    if all_probs:
                        full_logits = output_context
                else:
                    output_context = F.log_softmax(output, 2)
                    indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
                    new_output_logits = torch.gather(output_context, 2, indices).squeeze(2)

                    # TODO(rprenger) we're copying output_logits every time. Should pre-allocate
                    output_logits = torch.cat([output_logits, new_output_logits], 1)
                    if all_probs:
                        full_logits = torch.cat([full_logits, output_context], 1)

                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eod_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
                if all_probs:
                    yield tokens, lengths, output_logits, full_logits
                else:
                    yield tokens, lengths, output_logits, None

            else:
                if parallel_state.is_pipeline_first_stage():
                    src = parallel_state.get_pipeline_model_parallel_last_rank()
                    group = parallel_state.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None, None, None
                else:
                    yield None, None, None, None

                done = torch.cuda.ByteTensor([0])
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break

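# Illustrative arithmetic for the tabular row offset above, using hypothetical
# column sizes (not taken from any real tokenizer config):
# with tokenizer.code_column.sizes == [2, 3], each row costs
#   tokens_per_row = 2 + 3 + 1 = 6   (the +1 is the end-of-row token)
# and if the last '\n'/'<|endoftext|>' in a 13-token context sits at position 9,
#   offset = (13 - 9 - 1) % 6 = 3
# so generation resumes 3 tokens into the current row and the per-column token
# ranges are applied accordingly.
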
def main(cfg) -> None:
    # trainer required for restoring model parallel models
    trainer = Trainer(plugins=NLPDDPPlugin(), **cfg.trainer)
    assert (
        cfg.trainer.devices * cfg.trainer.num_nodes
        == cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size
    ), "devices * num_nodes should equal tensor_model_parallel_size * pipeline_model_parallel_size"

    # Load prompt tuned model, virtual_prompt_model_file must be provided in config
    if cfg.get('virtual_prompt_model_file', None) is not None:
        # Update frozen GPT model path in case it has changed
        prompt_learning_cfg = MegatronGPTPromptLearningModel.restore_from(
            cfg.virtual_prompt_model_file, trainer=trainer, return_config=True)
        with open_dict(prompt_learning_cfg):
            prompt_learning_cfg.language_model_path = cfg.gpt_model_file

        # Now load prompt learning model with frozen gpt model base
        model = MegatronGPTPromptLearningModel.restore_from(
            restore_path=cfg.virtual_prompt_model_file,
            trainer=trainer,
            override_config_path=prompt_learning_cfg)

    # Or load regular GPT model
    elif cfg.gpt_model_file:
        model = MegatronGPTModel.restore_from(restore_path=cfg.gpt_model_file, trainer=trainer)
    elif cfg.checkpoint_dir:
        app_state = AppState()
        if cfg.tensor_model_parallel_size > 1 or cfg.pipeline_model_parallel_size > 1:
            app_state.model_parallel_size = cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size
            (
                app_state.tensor_model_parallel_rank,
                app_state.pipeline_model_parallel_rank,
                app_state.model_parallel_size,
                app_state.data_parallel_size,
                app_state.pipeline_model_parallel_split_rank,
            ) = fake_initialize_model_parallel(
                world_size=app_state.model_parallel_size,
                rank=trainer.global_rank,
                tensor_model_parallel_size_=cfg.tensor_model_parallel_size,
                pipeline_model_parallel_size_=cfg.pipeline_model_parallel_size,
                pipeline_model_parallel_split_rank_=cfg.pipeline_model_parallel_split_rank,
            )
        checkpoint_path = inject_model_parallel_rank(
            os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name))
        model = MegatronGPTModel.load_from_checkpoint(
            checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer)
    else:
        raise ValueError("need at least a nemo file or checkpoint dir")

    model.freeze()

    # Have to turn off activations_checkpoint_method for inference
    try:
        model.model.language_model.encoder.activations_checkpoint_method = None
    except AttributeError:
        pass

    try:
        model.frozen_model.language_model.encoder.activations_checkpoint_method = None
    except AttributeError:
        pass

    length_params: LengthParam = {
        "max_length": cfg.inference.tokens_to_generate,
        "min_length": cfg.inference.min_tokens_to_generate,
    }

    sampling_params: SamplingParam = {
        "use_greedy": cfg.inference.greedy,
        "temperature": cfg.inference.temperature,
        "top_k": cfg.inference.top_k,
        "top_p": cfg.inference.top_p,
        "repetition_penalty": cfg.inference.repetition_penalty,
        "add_BOS": cfg.inference.add_BOS,
        "all_probs": cfg.inference.all_probs,
        "compute_logprob": cfg.inference.compute_logprob,
    }

    # First method of running text generation, call model.generate method
    response = model.generate(
        inputs=OmegaConf.to_container(cfg.prompts),
        length_params=length_params,
        sampling_params=sampling_params)

    print("***************************")
    print(response)
    print("***************************")

    # Second method of running text generation, call trainer.predict
    collate_fn = None
    if cfg.get('virtual_prompt_model', False):
        collate_fn = lambda x: list(x)

    ds = RequestDataSet(OmegaConf.to_container(cfg.prompts))
    request_dl = DataLoader(dataset=ds, collate_fn=collate_fn, batch_size=2)
    config = OmegaConf.to_container(cfg.inference)
    model.set_inference_config(config)
    response = trainer.predict(model, request_dl)

    print("***************************")
    print(response)
    print("***************************")

    # Third method of running text generation, use inference server
    if cfg.server:
        if parallel_state.is_pipeline_first_stage() and parallel_state.get_tensor_model_parallel_rank() == 0:
            server = MegatronServer(model.cuda())
            server.run("0.0.0.0", port=cfg.port)

        while True:
            choice = torch.cuda.LongTensor(1)
            torch.distributed.broadcast(choice, 0)
            if choice[0].item() == 0:
                generate(model.cuda())

def sample_sequence_batch(
    model,
    context_tokens,
    context_lengths,
    task_ids,
    attention_mask,
    position_ids,
    tokens_to_generate,
    all_probs=False,
    type_ids=None,
    temperature=None,
    extra={},
):
    # Importing here to avoid circular import errors
    from nemo.collections.nlp.models.language_modeling import MegatronGPTPromptLearningModel

    app_state = AppState()
    micro_batch_size = context_tokens.shape[0]
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=micro_batch_size,
        micro_batch_size=micro_batch_size,
        data_parallel_size=1,
    )
    tokenizer = model.tokenizer
    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()

        # added eos_id to support the function generate_samples_eval that passes
        # eos_id as an argument and needs termination when that id is found.
        eod_id = tokenizer.eos_id
        counter = 0

        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        output_logits = None
        all_generated_indices = None  # used to track all generated indices

        # Generate enough tokens for the longest sequence
        maxlen = tokens_to_generate + context_lengths.max().item()
        if maxlen > model.cfg.encoder_seq_length + 1:
            maxlen = model.cfg.encoder_seq_length + 1

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length < maxlen:
            # types2use = None
            if counter == 0:
                # Allocate memory for the entire context.
                set_inference_key_value_memory = True
                tokens2use = tokens[:, :context_length]
                positions2use = position_ids[:, :context_length]
                # not using type2use. uncomment it if it is used
                # if type_ids is not None:
                #     types2use = type_ids[:, :context_length]
            else:
                # Set this to false so the memory is not reallocated.
                set_inference_key_value_memory = False
                tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
                positions2use = position_ids[:, context_length - 1].view(batch_size, -1)
                # not using type2use. uncomment it if it is used
                # if type_ids is not None:
                #     types2use = type_ids[:, context_length - 1].view(batch_size, -1)

            attention_mask_repeat = torch.concat(
                [attention_mask for _ in range(micro_batch_size)])
            setkey_value_array = torch.tensor(
                [set_inference_key_value_memory] * micro_batch_size,
                device=torch.cuda.current_device())
            len_array = torch.tensor([maxlen] * micro_batch_size,
                                     device=torch.cuda.current_device())

            # Only prompt learning models will have a prompt table, and require task ids
            if isinstance(model, MegatronGPTPromptLearningModel):
                batch = [tokens2use, attention_mask_repeat, positions2use, task_ids, setkey_value_array, len_array]
                tensor_shape = [tokens2use.shape[1], micro_batch_size, model.frozen_model.cfg.hidden_size]
            else:
                batch = [tokens2use, attention_mask_repeat, positions2use, setkey_value_array, len_array]
                tensor_shape = [tokens2use.shape[1], micro_batch_size, model.cfg.hidden_size]

            output = forward_step(model, batch, tensor_shape)

            if parallel_state.is_pipeline_last_stage():
                output = output[0]['logits'].float()
                output = tensor_parallel.gather_from_tensor_model_parallel_region(output)
                assert output is not None
                output = output.float()
                logits = output[:, -1].view(batch_size, -1).contiguous()

                # make sure it will generate at least min_length
                min_length = extra.get('min_tokens_to_generate', 0)
                if min_length > 0:
                    within_min_length = (context_length - context_lengths) < min_length
                    logits[within_min_length, eod_id] = -float('Inf')

                # make sure it won't sample outside the vocab_size range
                logits[:, tokenizer.vocab_size:] = -float('Inf')

                if extra.get('greedy', False):
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= temperature
                    # handle repetition penalty
                    logits = repetition_penalty(
                        logits, extra.get('repetition_penalty', 1.2), all_generated_indices)
                    logits = top_k_logits(
                        logits, top_k=extra.get('top_k', 0), top_p=extra.get('top_p', 0.9))
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs, num_samples=1).view(-1)
                started = context_lengths <= context_length

                # Clamp the predicted out of vocabulary tokens
                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)
                new_tokens = switch(tokens[:, context_length].view(-1), prev, started)

                # Replace sampled tokens w/ done token if EOD has already been sampled
                new_tokens = switch(new_tokens, eod_id, is_done)

                # Replace special soft prompt token ids with unk token ids
                if isinstance(model, MegatronGPTPromptLearningModel):
                    pseudo_token_ids_start = model.pseudo_token_ids_start
                    new_tokens[(new_tokens >= pseudo_token_ids_start)] = tokenizer.unk_id
                    tokens[:, :context_length][
                        (tokens[:, :context_length] >= pseudo_token_ids_start)] = tokenizer.unk_id

                # Insert either new predicted or next prompt token
                tokens[:, context_length] = new_tokens

                if output_logits is None:
                    output = F.log_softmax(output[:, :context_length, :], 2)
                    indices = torch.unsqueeze(tokens[:, 1:context_length + 1], 2)
                    output_logits = torch.gather(output, 2, indices).squeeze(2)
                    all_generated_indices = indices[:, :, 0]
                    if all_probs:
                        full_logits = output
                else:
                    output = F.log_softmax(output, 2)
                    indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
                    new_output_logits = torch.gather(output, 2, indices).squeeze(2)

                    # TODO(rprenger) we're copying output_logits every time. Should pre-allocate
                    output_logits = torch.cat([output_logits, new_output_logits], 1)
                    all_generated_indices = torch.cat([all_generated_indices, indices[:, :, 0]], 1)
                    if all_probs:
                        full_logits = torch.cat([full_logits, output], 1)

                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eod_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
                if all_probs:
                    yield tokens, lengths, output_logits, full_logits
                else:
                    yield tokens, lengths, output_logits, None

            else:
                if parallel_state.is_pipeline_first_stage():
                    src = parallel_state.get_pipeline_model_parallel_last_rank()
                    group = parallel_state.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None, None, None
                else:
                    yield None, None, None, None

                done = torch.cuda.ByteTensor([0])
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break