def generate(
    self,
    inputs: Union[List[str], torch.Tensor, List[dict]],
    length_params: LengthParam,
    sampling_params: SamplingParam = None,
) -> OutputType:

    # check whether the DDP is initialized
    if parallel_state.is_unitialized():

        def dummy():
            return

        if self.trainer.strategy.launcher is not None:
            self.trainer.strategy.launcher.launch(dummy, trainer=self.trainer)
        self.trainer.strategy.setup_environment()

    # set the default sampling params if it is None.
    # default to greedy sampling
    if sampling_params is None:
        sampling_params = get_default_sampling_params()

    # set the default length params if it is None.
    if length_params is None:
        length_params = get_default_length_params()

    return megatron_gpt_generate(self.cuda(), inputs, self.tokenizer, length_params, sampling_params)
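
# Example usage (a minimal sketch, not part of the source). Assumes `model` is an already restored
# model exposing the generate() above, and that LengthParam / SamplingParam are the dict-like
# parameter containers passed through to megatron_gpt_generate(); the "max_length" key is the one
# read by get_all_examples() in the prompt-learning generate() below, everything else is illustrative.
length_params: LengthParam = get_default_length_params()
length_params["max_length"] = 50  # cap on the number of generated tokens (assumed semantics)

sampling_params: SamplingParam = get_default_sampling_params()  # greedy defaults, per the comment above

output = model.generate(
    inputs=["Deep learning is"],  # plain strings; token tensors or dicts are also accepted
    length_params=length_params,
    sampling_params=sampling_params,
)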
def setup_distributed(self, global_rank: int = None, world_size: int = None) -> None:
    # call PTL init ddp
    super().setup_distributed()

    # init model parallel if needed
    if parallel_state.is_unitialized():
        app_state = AppState()
        if app_state.model_parallel_size is not None:
            self.init_model_parallel(app_state.global_rank, app_state.world_size)
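
# Minimal sketch (not from the source) of what this hook relies on: AppState is a singleton, so
# whatever rank / world-size / model-parallel information was recorded elsewhere (e.g. when a
# model-parallel checkpoint is restored) is visible here without being passed in explicitly.
# The value below is purely illustrative.
app_state = AppState()
app_state.model_parallel_size = 2  # illustrative: recorded before the trainer calls setup_distributed()
# When the trainer later invokes setup_distributed(), the branch above sees this value and calls
# init_model_parallel(app_state.global_rank, app_state.world_size).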
def generate(
    self,
    inputs: Union[List[str], torch.Tensor, List[dict]],
    length_params: LengthParam,
    sampling_params: SamplingParam = None,
):

    # check whether the DDP is initialized
    if parallel_state.is_unitialized():

        def dummy():
            return

        if self.trainer.strategy.launcher is not None:
            self.trainer.strategy.launcher.launch(dummy, trainer=self.trainer)
        self.trainer.strategy.setup_environment()

    # set the default sampling params if it is None.
    # default to greedy sampling
    if sampling_params is None:
        sampling_params = get_default_sampling_params()
        sampling_params["add_BOS"] = self.cfg.data.get("add_bos", False)

    # set the default length params if it is None.
    if length_params is None:
        length_params = get_default_length_params()

    # Preprocess inputs to be what they need to be for the generate code
    dataset = GPTPromptLearningDataset(
        datasets=inputs,
        tokenizer=self.tokenizer,
        virtual_prompt_source=self.virtual_prompt_source,
        task_templates=self.task_templates,
        pseudo_tokens=self.pseudo_tokens,
        pad_token_id=self.pad_token_id,
        max_seq_length=self.cfg.data.get('max_seq_length', self.frozen_model.cfg.max_position_embeddings),
        min_seq_length=self.cfg.data.get('min_seq_length', 1),
        add_bos=sampling_params["add_BOS"],
        add_eos=False,
        for_train=False,
    )
    task_ids, processed_inputs = dataset.get_all_examples(tokens_to_generate=length_params['max_length'])
    self.frozen_model.model.parallel_output = False

    # Call the same generate code as in MegatronGPT
    return megatron_gpt_generate(
        self.cuda(), processed_inputs, self.tokenizer, length_params, sampling_params, task_ids
    )
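
# Example usage (a minimal sketch, not part of the source). Prompt-learning inference takes
# List[dict] inputs that GPTPromptLearningDataset renders against the configured task templates;
# the "taskname"/"text" field names below are assumptions that depend on how the task templates
# were defined for the restored model, and `prompt_learning_model` is a hypothetical instance.
prompts = [
    {"taskname": "sentiment", "text": "The plot was predictable but the acting was great."},
    {"taskname": "sentiment", "text": "I would not watch this again."},
]

response = prompt_learning_model.generate(
    inputs=prompts,
    length_params=get_default_length_params(),
    sampling_params=None,  # falls back to the greedy defaults set inside generate()
)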
def decode(self, tokens_enc, enc_mask, num_tokens_to_generate, encoder_input=None, tokenizer=None):
    # Check whether the DDP is initialized. This is needed when running inference outside of the training loop.
    if parallel_state.is_unitialized():

        def dummy():
            return

        if self.trainer.strategy.launcher is not None:
            self.trainer.strategy.launcher.launch(dummy, trainer=self.trainer)
        self.trainer.strategy.setup_environment()

        # Reconfigure microbatch sizes here because on model restore, this will contain the
        # micro/global batch configuration used while training.
        _reconfigure_microbatch_calculator(
            rank=0,  # This doesn't matter since it is only used for logging
            rampup_batch_size=None,
            global_batch_size=1,
            micro_batch_size=1,  # Make sure that there is no "grad acc" while decoding.
            data_parallel_size=1,  # We check above to make sure that the data parallel size is always 1 at inference.
        )

    # If classes that inherit from this class are using a different tokenizer, use it instead of the default one.
    tokenizer = self.tokenizer if tokenizer is None else tokenizer
    app_state = AppState()
    global_batch_per_gpu = tokens_enc.size(0)

    num_micro_batches_before_decode = get_num_microbatches()
    # Reconfigure the microbatch calculator to set the number of microbatches to 1 while decoding,
    # since it's not clear how to decode with "grad acc".
    # TODO: reconfigure back to how things were before decode?
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
        micro_batch_size=global_batch_per_gpu,  # Make sure that there is no "grad acc" while decoding.
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )

    # Start every sequence in the batch with the BOS token; one token per step is appended below.
    predicted_tokens_dec = (
        torch.LongTensor([tokenizer.bos_id] * global_batch_per_gpu).unsqueeze(1).to(tokens_enc.device)
    )
    encoder_seq_length = tokens_enc.size(1)
    tensor_shape = [encoder_seq_length, global_batch_per_gpu, self.cfg.hidden_size]
    assert predicted_tokens_dec.size(0) == global_batch_per_gpu

    for i in range(num_tokens_to_generate):
        # No microbatches in decoding. Just the global batch.
        decoder_seq_length = predicted_tokens_dec.size(1)
        dec_mask = predicted_tokens_dec != tokenizer.pad_id

        if encoder_input is not None:
            batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask, encoder_input]
        else:
            batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask]

        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            output_tensor = forward_backward_pipelining_without_interleaving(
                forward_step_func=self.get_forward_output_only_func(),
                batch=batch_for_pipeline,
                model=self.enc_dec_model,
                forward_only=True,
                tensor_shape=tensor_shape,
                decoder_sequence_length=decoder_seq_length,
                dtype=self.autocast_dtype,
            )
        else:
            output_tensor = forward_backward_no_pipelining(
                forward_step_func=self.get_forward_output_only_func(),
                batch=batch_for_pipeline,
                model=self.enc_dec_model,
                forward_only=True,
                tensor_shape=tensor_shape,
                decoder_sequence_length=decoder_seq_length,
                dtype=self.autocast_dtype,
            )

        # get output tensor
        if parallel_state.is_pipeline_last_stage():
            output_tensor = output_tensor[0]['logits']
            output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
            # Greedy decoding: take the argmax of the log-softmaxed logits and append the token
            # predicted at the last position.
            log_probs, token_ids = torch.max(torch.nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)
            predicted_tokens_dec = torch.cat(
                [predicted_tokens_dec.to(token_ids.device), token_ids[:, -1].unsqueeze(1)], dim=1
            )
        else:
            log_probs = torch.zeros(
                (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1]), dtype=self.autocast_dtype
            ).cuda()
            predicted_tokens_dec = torch.zeros(
                (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1] + 1),
                dtype=predicted_tokens_dec.dtype,
            ).cuda()

        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            # Broadcast from the last pipeline stage to all other model-parallel ranks.
            torch.distributed.broadcast(
                predicted_tokens_dec,
                parallel_state.get_pipeline_model_parallel_last_rank(),
                group=parallel_state.get_model_parallel_group(),
            )
            torch.distributed.broadcast(
                log_probs,
                parallel_state.get_pipeline_model_parallel_last_rank(),
                group=parallel_state.get_model_parallel_group(),
            )

    # Reset microbatch calculator to what it was before decoding.
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
        micro_batch_size=global_batch_per_gpu // num_micro_batches_before_decode,
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )

    return predicted_tokens_dec, log_probs
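
# Example usage (a minimal sketch, not part of the source). Assumes `model` is a restored
# encoder-decoder model exposing the decode() above, and that its tokenizer provides
# text_to_ids()/ids_to_text() and a pad_id (the helper names are assumptions). Decoding is
# greedy: each step appends the argmax token, as implemented above.
tokenizer = model.tokenizer
enc_ids = [tokenizer.text_to_ids("translate English to German: The house is wonderful.")]
tokens_enc = torch.LongTensor(enc_ids).cuda()
enc_mask = (tokens_enc != tokenizer.pad_id).long()

predicted_tokens_dec, log_probs = model.decode(tokens_enc, enc_mask, num_tokens_to_generate=32)
print(tokenizer.ids_to_text(predicted_tokens_dec[0].tolist()))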