def allreduce_word_and_position_embeddings(self):
    # Modified from megatron-lm:
    # https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/training.py#L407
    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    if parallel_state.get_pipeline_model_parallel_world_size() > 1 and (
        parallel_state.is_rank_in_embedding_group()
    ):
        if self.enc_dec_model.share_word_embeddings:
            word_embeddings_weight = self.enc_dec_model.word_embeddings_weight()
            if self.megatron_amp_o2:
                # O2 recipe stores a "main" copy of weights and grads
                grad = word_embeddings_weight.main_grad
            else:
                grad = word_embeddings_weight.grad
            torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())

    # All-reduce position embeddings for T5.
    if (
        parallel_state.is_rank_in_position_embedding_group()
        and parallel_state.get_pipeline_model_parallel_world_size() > 1
        and parallel_state.get_pipeline_model_parallel_split_rank() is not None
    ):
        position_embeddings_weight = self.enc_dec_model.position_embeddings_weight()
        if self.megatron_amp_o2:
            grad = position_embeddings_weight.main_grad
        else:
            grad = position_embeddings_weight.grad
        torch.distributed.all_reduce(grad, group=parallel_state.get_position_embedding_group())
def initialize_word_embeddings(self, init_method, vocab_size, hidden_size, pipeline_model_parallel_size=1):
    if not self.share_word_embeddings:
        raise Exception('initialize_word_embeddings() was called but share_word_embeddings is false')

    # TODO: pipeline model parallelism is not implemented in NeMo yet
    # This function just initializes the word embeddings in the final stage
    # when we are using pipeline parallelism. If we aren't using pipeline
    # parallelism there is nothing to do.
    if pipeline_model_parallel_size == 1:
        return

    # Parameters are shared between the word embeddings layer and the heads
    # at the end of the model. In a pipelined setup with more than one stage,
    # the initial embedding layer and the head are on different workers, so
    # we do the following:
    # 1. Create a second copy of word_embeddings on the last stage, with
    #    initial parameters of 0.0.
    # 2. Do an all-reduce between the first and last stage to ensure that
    #    the two copies of word_embeddings start off with the same
    #    parameter values.
    # 3. In the training loop, perform an all-reduce between the grads of
    #    the two word_embeddings layers to ensure that every applied weight
    #    update is the same on both stages.
    if parallel_state.is_pipeline_last_stage():
        assert not parallel_state.is_pipeline_first_stage()
        self._word_embeddings_for_head_key = 'word_embeddings_for_head'
        # Set word_embeddings weights to 0 here, then copy the first
        # stage's weights using the all_reduce below.
        self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
            vocab_size, hidden_size, init_method=init_method
        )
        self.word_embeddings.weight.data.fill_(0)
        self.word_embeddings.weight.shared = True

    # Ensure that first and last stages have the same initial parameter values.
    if torch.distributed.is_initialized():
        if parallel_state.is_pipeline_first_stage() or parallel_state.is_pipeline_last_stage():
            torch.distributed.all_reduce(
                self.word_embeddings_weight().data, group=parallel_state.get_embedding_group()
            )
    else:
        print(
            "WARNING! Distributed processes aren't initialized, so "
            "word embeddings in the last layer are not initialized. "
            "If you are just manipulating a model this is fine, but "
            "this needs to be handled manually. If you are training "
            "something is definitely wrong."
        )
def sync_initial_word_embeddings(self):
    if torch.distributed.is_initialized():
        if parallel_state.is_pipeline_first_stage() or parallel_state.is_pipeline_last_stage():
            torch.distributed.all_reduce(
                self.word_embeddings_weight().data, group=parallel_state.get_embedding_group()
            )
    else:
        logging.warning(
            "WARNING! Distributed processes aren't initialized, so "
            "word embeddings in the last layer are not synchronized. "
            "If you are just manipulating a model this is fine, but "
            "this needs to be handled manually. If you are training "
            "something is definitely wrong."
        )
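# Illustration (not NeMo/Megatron code): a minimal, self-contained sketch of the
# tied-embedding trick the three functions above implement. Two ranks hold separate
# copies of the same weight; one all-reduce first makes the initial values identical,
# and an all-reduce of the gradients keeps every optimizer step identical on both
# copies. Assumes the "gloo" backend and two CPU processes; `run_rank` is a
# hypothetical driver written for this sketch only.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run_rank(rank, world_size):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Rank 0 plays the first pipeline stage (real values); rank 1 plays the last
    # stage, which starts its copy at zero and receives the values via all-reduce.
    weight = torch.zeros(4, 3, requires_grad=True)
    if rank == 0:
        with torch.no_grad():
            weight.copy_(torch.randn(4, 3))
    dist.all_reduce(weight.data)  # both ranks now hold identical initial values

    # Each rank computes a different local loss; all-reducing the gradients makes
    # the applied update identical everywhere, so the copies never drift apart.
    loss = (weight * (rank + 1)).sum()
    loss.backward()
    dist.all_reduce(weight.grad)

    with torch.no_grad():
        weight -= 0.1 * weight.grad
    print(f"rank {rank} weight sum after step: {weight.sum().item():.4f}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run_rank, args=(2,), nprocs=2)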
def tab_sample_sequence_batch(
    model,
    context_tokens,
    context_lengths,
    attention_mask,
    position_ids,
    tokens_to_generate,
    all_probs=True,
    type_ids=None,
    temperature=None,
):
    app_state = AppState()
    micro_batch_size = context_tokens.shape[0]
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=micro_batch_size,
        micro_batch_size=micro_batch_size,
        data_parallel_size=1,
    )
    tokenizer = model.tokenizer
    sizes = tokenizer.code_column.sizes
    tokens_per_row = sum(sizes) + 1
    columns = tokenizer.code_column.columns
    num_columns = len(columns)
    tokenid_range = []
    for i in range(num_columns):
        tokenid_range.extend(tokenizer.code_column.get_range(i))

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        context = context_tokens[:, :context_length]
        # The context may start in the middle of a row; calculate the offset
        # according to the position of '\n' or '<|endoftext|>'.
        positions = torch.where(context == tokenizer.eor)[1]
        if len(positions) == 0:
            positions = torch.where(context == tokenizer.eod)[1]
        if len(positions) != 0:
            max_position = positions.max().item()
            # TODO: need to make sure the contexts in a batch share the same offset;
            # otherwise the offset has to be calculated per batch_id.
            offset = (context_length - max_position - 1) % tokens_per_row
        else:
            offset = 0
        eod_id = tokenizer.eos_id
        counter = 0
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        output_logits = None

        # Generate enough tokens for the longest sequence.
        maxlen = tokens_to_generate + context_lengths.max().item()
        if maxlen > model.cfg.encoder_seq_length:
            maxlen = model.cfg.encoder_seq_length
        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length < maxlen:
            # types2use = None
            if counter == 0:
                # Allocate memory for the entire context.
                set_inference_key_value_memory = True
                tokens2use = tokens[:, :context_length]
                positions2use = position_ids[:, :context_length]
                # not using types2use; uncomment if it is needed
                # if type_ids is not None:
                #     types2use = type_ids[:, :context_length]
            else:
                # Set this to false so the memory is not reallocated.
                set_inference_key_value_memory = False
                tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
                positions2use = position_ids[:, context_length - 1].view(batch_size, -1)
                # not using types2use; uncomment if it is needed
                # if type_ids is not None:
                #     types2use = type_ids[:, context_length - 1].view(batch_size, -1)

            # micro_batch_size = 2
            attention_mask_repeat = torch.concat([attention_mask for _ in range(micro_batch_size)])
            setkey_value_array = torch.tensor(
                [set_inference_key_value_memory] * micro_batch_size, device=torch.cuda.current_device()
            )
            len_array = torch.tensor([maxlen] * micro_batch_size, device=torch.cuda.current_device())
            batch = [tokens2use, attention_mask_repeat, positions2use, setkey_value_array, len_array]
            tensor_shape = [tokens2use.shape[1], micro_batch_size, model.cfg.hidden_size]

            output = forward_step(model, batch, tensor_shape)

            if parallel_state.is_pipeline_last_stage():
                output = output[0]['logits'].float()
                output = tensor_parallel.gather_from_tensor_model_parallel_region(output)
                assert output is not None
                output = output.float()
                logits = output[:, -1].view(batch_size, -1).contiguous()
                token_in_row = (counter + offset) % tokens_per_row
                logits = logits.float()
                logits /= temperature
                if token_in_row == tokens_per_row - 1:
                    # line break
                    eor_id = tokenizer.eor
                    eod_id = tokenizer.eos_id
                    min_id = min(eor_id, eod_id)
                    max_id = max(eor_id, eod_id) + 1
                    logits = tab_logits(logits, min_id, max_id)
                else:
                    # limit sampling to the current column's token-id range
                    min_id, max_id = tokenid_range[token_in_row]
                    logits = tab_logits(logits, min_id, max_id)
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1).view(-1)
                started = context_lengths <= context_length
                # Clamp the out-of-vocabulary tokens.
                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)
                new_tokens = switch(tokens[:, context_length].view(-1), prev, started)
                tokens[:, context_length] = new_tokens

                if output_logits is None:
                    output_context = F.log_softmax(output[:, :context_length, :], 2)
                    indices = torch.unsqueeze(tokens[:, 1:context_length + 1], 2)
                    output_logits = torch.gather(output_context, 2, indices).squeeze(2)
                    if all_probs:
                        full_logits = output_context
                else:
                    output_context = F.log_softmax(output, 2)
                    indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
                    new_output_logits = torch.gather(output_context, 2, indices).squeeze(2)
                    # TODO(rprenger): we're copying output_logits every time; should pre-allocate.
                    output_logits = torch.cat([output_logits, new_output_logits], 1)
                    if all_probs:
                        full_logits = torch.cat([full_logits, output_context], 1)

                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eod_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token
                done = torch.all(is_done)

                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

                if all_probs:
                    yield tokens, lengths, output_logits, full_logits
                else:
                    yield tokens, lengths, output_logits, None
            else:
                if parallel_state.is_pipeline_first_stage():
                    src = parallel_state.get_pipeline_model_parallel_last_rank()
                    group = parallel_state.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None, None, None
                else:
                    yield None, None, None, None

                done = torch.cuda.ByteTensor([0])
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break
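# Illustration (not NeMo code): the loop above narrows sampling to the current
# column's token-id range via `tab_logits`. The helper below is a plausible minimal
# version of such a range mask, written for this sketch only; NeMo's actual
# implementation may differ. Every logit outside [min_id, max_id) is pushed to -inf
# so the softmax only spreads probability over the column's slice of the vocabulary.
import torch


def range_mask_logits(logits: torch.Tensor, min_id: int, max_id: int) -> torch.Tensor:
    # Start from an all -inf tensor and copy back only the allowed slice.
    masked = torch.full_like(logits, float("-inf"))
    masked[:, min_id:max_id] = logits[:, min_id:max_id]
    return masked


if __name__ == "__main__":
    scores = torch.randn(2, 10)
    probs = torch.softmax(range_mask_logits(scores, 3, 6), dim=-1)
    print(probs)  # only columns 3..5 carry probability mass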
def sample_sequence_batch(
    model,
    context_tokens,
    context_lengths,
    task_ids,
    attention_mask,
    position_ids,
    tokens_to_generate,
    all_probs=False,
    type_ids=None,
    temperature=None,
    extra={},
):
    # Importing here to avoid circular import errors
    from nemo.collections.nlp.models.language_modeling import MegatronGPTPromptLearningModel

    app_state = AppState()
    micro_batch_size = context_tokens.shape[0]
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=micro_batch_size,
        micro_batch_size=micro_batch_size,
        data_parallel_size=1,
    )
    tokenizer = model.tokenizer
    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()

        # eod_id is set to eos_id to support generate_samples_eval, which passes
        # eos_id as an argument and needs termination when that id is found.
        eod_id = tokenizer.eos_id

        counter = 0
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        output_logits = None
        all_generated_indices = None  # used to track all generated indices

        # Generate enough tokens for the longest sequence.
        maxlen = tokens_to_generate + context_lengths.max().item()
        if maxlen > model.cfg.encoder_seq_length + 1:
            maxlen = model.cfg.encoder_seq_length + 1
        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length < maxlen:
            # types2use = None
            if counter == 0:
                # Allocate memory for the entire context.
                set_inference_key_value_memory = True
                tokens2use = tokens[:, :context_length]
                positions2use = position_ids[:, :context_length]
                # not using types2use; uncomment if it is needed
                # if type_ids is not None:
                #     types2use = type_ids[:, :context_length]
            else:
                # Set this to false so the memory is not reallocated.
                set_inference_key_value_memory = False
                tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
                positions2use = position_ids[:, context_length - 1].view(batch_size, -1)
                # not using types2use; uncomment if it is needed
                # if type_ids is not None:
                #     types2use = type_ids[:, context_length - 1].view(batch_size, -1)

            attention_mask_repeat = torch.concat([attention_mask for _ in range(micro_batch_size)])
            setkey_value_array = torch.tensor(
                [set_inference_key_value_memory] * micro_batch_size, device=torch.cuda.current_device()
            )
            len_array = torch.tensor([maxlen] * micro_batch_size, device=torch.cuda.current_device())

            # Only prompt learning models have a prompt table and require task ids.
            if isinstance(model, MegatronGPTPromptLearningModel):
                batch = [tokens2use, attention_mask_repeat, positions2use, task_ids, setkey_value_array, len_array]
                tensor_shape = [tokens2use.shape[1], micro_batch_size, model.frozen_model.cfg.hidden_size]
            else:
                batch = [tokens2use, attention_mask_repeat, positions2use, setkey_value_array, len_array]
                tensor_shape = [tokens2use.shape[1], micro_batch_size, model.cfg.hidden_size]

            output = forward_step(model, batch, tensor_shape)

            if parallel_state.is_pipeline_last_stage():
                output = output[0]['logits'].float()
                output = tensor_parallel.gather_from_tensor_model_parallel_region(output)
                assert output is not None
                output = output.float()
                logits = output[:, -1].view(batch_size, -1).contiguous()

                # Make sure at least min_length tokens are generated.
                min_length = extra.get('min_tokens_to_generate', 0)
                if min_length > 0:
                    within_min_length = (context_length - context_lengths) < min_length
                    logits[within_min_length, eod_id] = -float('Inf')

                # Make sure sampling stays within the vocab_size range.
                logits[:, tokenizer.vocab_size:] = -float('Inf')

                if extra.get('greedy', False):
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= temperature
                    # handle repetition penalty
                    logits = repetition_penalty(logits, extra.get('repetition_penalty', 1.2), all_generated_indices)
                    logits = top_k_logits(logits, top_k=extra.get('top_k', 0), top_p=extra.get('top_p', 0.9))
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs, num_samples=1).view(-1)
                started = context_lengths <= context_length

                # Clamp the predicted out-of-vocabulary tokens.
                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)
                new_tokens = switch(tokens[:, context_length].view(-1), prev, started)

                # Replace sampled tokens with the done token if EOD has already been sampled.
                new_tokens = switch(new_tokens, eod_id, is_done)

                # Replace special soft prompt token ids with unk token ids.
                if isinstance(model, MegatronGPTPromptLearningModel):
                    pseudo_token_ids_start = model.pseudo_token_ids_start
                    new_tokens[(new_tokens >= pseudo_token_ids_start)] = tokenizer.unk_id
                    tokens[:, :context_length][
                        (tokens[:, :context_length] >= pseudo_token_ids_start)
                    ] = tokenizer.unk_id

                # Insert either the new predicted token or the next prompt token.
                tokens[:, context_length] = new_tokens

                if output_logits is None:
                    output = F.log_softmax(output[:, :context_length, :], 2)
                    indices = torch.unsqueeze(tokens[:, 1:context_length + 1], 2)
                    output_logits = torch.gather(output, 2, indices).squeeze(2)
                    all_generated_indices = indices[:, :, 0]
                    if all_probs:
                        full_logits = output
                else:
                    output = F.log_softmax(output, 2)
                    indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
                    new_output_logits = torch.gather(output, 2, indices).squeeze(2)
                    # TODO(rprenger): we're copying output_logits every time; should pre-allocate.
                    output_logits = torch.cat([output_logits, new_output_logits], 1)
                    all_generated_indices = torch.cat([all_generated_indices, indices[:, :, 0]], 1)
                    if all_probs:
                        full_logits = torch.cat([full_logits, output], 1)

                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eod_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token
                done = torch.all(is_done)

                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

                if all_probs:
                    yield tokens, lengths, output_logits, full_logits
                else:
                    yield tokens, lengths, output_logits, None
            else:
                if parallel_state.is_pipeline_first_stage():
                    src = parallel_state.get_pipeline_model_parallel_last_rank()
                    group = parallel_state.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None, None, None
                else:
                    yield None, None, None, None

                done = torch.cuda.ByteTensor([0])
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break
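# Illustration (not NeMo code): `sample_sequence_batch` restricts sampling with a
# repetition penalty and a top-k / top-p (nucleus) filter via helpers such as
# `top_k_logits`. Below is a minimal sketch of combined top-k / top-p filtering,
# written for this example only; the helper actually used above may differ in detail.
import torch
import torch.nn.functional as F


def filter_logits(logits: torch.Tensor, top_k: int = 0, top_p: float = 0.0) -> torch.Tensor:
    logits = logits.clone()
    if top_k > 0:
        # Keep only the k largest logits in each row.
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits[logits < kth] = float("-inf")
    if top_p > 0.0:
        # Drop the low-probability tail whose cumulative mass exceeds top_p.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # shift right so the top token always survives
        remove[..., 0] = False
        # Map the "remove" flags back from sorted order to vocabulary order.
        mask = torch.zeros_like(logits).scatter_(-1, sorted_idx, remove.to(logits.dtype))
        logits = logits.masked_fill(mask.bool(), float("-inf"))
    return logits


if __name__ == "__main__":
    scores = torch.randn(1, 12)
    print(F.softmax(filter_logits(scores, top_k=5, top_p=0.9), dim=-1))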
def synced_generate(
    model,
    context_tokens_tensor,
    context_length_tensor,
    task_ids,
    tokens_to_generate,
    all_probs,
    temperature,
    top_k=0,
    top_p=0.0,
    greedy=False,
    repetition_penalty=1.2,
    min_tokens_to_generate=0,
):
    context_length = context_length_tensor.min().item()
    tokenizer = model.tokenizer
    tokens, attention_mask, position_ids = get_batch(model, tokenizer, context_tokens_tensor)

    if isinstance(tokenizer, TabularTokenizer):
        batch_token_iterator = tab_sample_sequence_batch(
            model,
            context_tokens_tensor,
            context_length_tensor,
            attention_mask,
            position_ids,
            tokens_to_generate,
            all_probs,
            temperature=temperature,
        )
    else:
        batch_token_iterator = sample_sequence_batch(
            model,
            context_tokens_tensor,
            context_length_tensor,
            task_ids,
            attention_mask,
            position_ids,
            tokens_to_generate,
            all_probs,
            temperature=temperature,
            extra={
                "top_p": top_p,
                "top_k": top_k,
                "greedy": greedy,
                "repetition_penalty": repetition_penalty,
                "min_tokens_to_generate": min_tokens_to_generate,
            },
        )

    for tokens, lengths, output_logits, full_logits in batch_token_iterator:
        context_length += 1

    if parallel_state.is_pipeline_last_stage():
        src = parallel_state.get_pipeline_model_parallel_last_rank()
        group = parallel_state.get_embedding_group()
        torch.distributed.broadcast(output_logits, src, group)
        if all_probs:
            src = parallel_state.get_pipeline_model_parallel_last_rank()
            group = parallel_state.get_embedding_group()
            torch.distributed.broadcast(full_logits, src, group)
    else:
        if parallel_state.is_pipeline_first_stage():
            src = parallel_state.get_pipeline_model_parallel_last_rank()
            group = parallel_state.get_embedding_group()
            output_logits = torch.empty(
                tokens.size(0), context_length - 1, dtype=torch.float32, device=torch.device("cuda")
            )
            torch.distributed.broadcast(output_logits, src, group)

            if all_probs:
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_embedding_group()
                full_logits = torch.empty(
                    tokens.size(0),
                    context_length - 1,
                    model.padded_vocab_size,
                    dtype=torch.float32,
                    device=torch.device("cuda"),
                )
                torch.distributed.broadcast(full_logits, src, group)

    if tokens is not None:
        return tokens[:, :context_length], output_logits, full_logits
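# Illustration (not NeMo code): `synced_generate` ships the finished logits from the
# pipeline-last rank to the other ranks with torch.distributed.broadcast; the receiving
# ranks allocate an empty buffer of the right shape and dtype and have it filled in
# place. A minimal, self-contained sketch of that receive pattern (gloo backend, two
# CPU processes; `run_rank` and `last_rank` are names for this sketch only).
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run_rank(rank, world_size):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    last_rank = world_size - 1
    if rank == last_rank:
        # The "last stage" actually computed the logits.
        output_logits = torch.randn(2, 7, dtype=torch.float32)
    else:
        # Every other rank only knows the shape/dtype and receives the values.
        output_logits = torch.empty(2, 7, dtype=torch.float32)

    dist.broadcast(output_logits, src=last_rank)
    print(f"rank {rank} received logits with sum {output_logits.sum().item():.4f}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run_rank, args=(2,), nprocs=2)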