def decode(self, tokens_enc, enc_mask, num_tokens_to_generate):
    # Run the encoder once; its hidden states are reused at every decoding step.
    encoder_hidden_states = self(
        encoder_input_ids=tokens_enc,
        decoder_input_ids=None,
        encoder_attn_mask=enc_mask,
        decoder_attn_mask=None,
        encoder_decoder_attn_mask=None,
        tokentype_ids=None,
        lm_labels=None,
        enc_hidden_states=None,
        output_enc_hidden_only=True,
    )
    # Start decoding from a single BOS token (this loop assumes batch size 1).
    predicted_tokens_dec = torch.LongTensor([self.tokenizer.bos_id]).unsqueeze(0).to(tokens_enc.device)
    for _ in range(num_tokens_to_generate):
        # Rebuild the decoder and cross-attention masks for the tokens predicted so far.
        enc_dec_mask = self.make_inference_attention_mask_3d(
            predicted_tokens_dec, tokens_enc, self.tokenizer.pad_id
        )
        dec_mask = self.make_inference_attention_mask_3d(
            predicted_tokens_dec, predicted_tokens_dec, self.tokenizer.pad_id
        )
        dec_mask = dec_mask * self.make_inference_history_mask_3d(predicted_tokens_dec)
        enc_dec_mask = enc_dec_mask < 0.5
        dec_mask = dec_mask < 0.5
        output_tensor, _ = self(
            encoder_input_ids=tokens_enc,
            decoder_input_ids=predicted_tokens_dec,
            encoder_attn_mask=enc_mask,
            decoder_attn_mask=dec_mask,
            encoder_decoder_attn_mask=enc_dec_mask,
            tokentype_ids=None,
            lm_labels=None,
            enc_hidden_states=encoder_hidden_states,
            output_enc_hidden_only=False,
        )
        output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
        log_probs, token_ids = torch.max(nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)
        predicted_tokens_dec = torch.cat([predicted_tokens_dec, token_ids[:, -1].unsqueeze(1)], 1)
        if token_ids[:, -1] == self.tokenizer.eos_id:
            break
    return predicted_tokens_dec, log_probs
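# The 3D masks built above combine padding masks with a causal "history" mask and
# are then inverted with `< 0.5`, so True marks positions to block. A minimal
# sketch of that construction in plain PyTorch (the helper names above are the
# model's own; everything here is illustrative):
import torch

def inference_masks_sketch(dec_ids, enc_ids, pad_id):
    # Padding masks: 1.0 where both query and key tokens are real (non-pad).
    dec_pad = (dec_ids != pad_id).float()
    enc_pad = (enc_ids != pad_id).float()
    enc_dec_mask = dec_pad.unsqueeze(2) * enc_pad.unsqueeze(1)  # (b, dec, enc)
    dec_mask = dec_pad.unsqueeze(2) * dec_pad.unsqueeze(1)      # (b, dec, dec)
    # Causal history mask: position i attends only to positions <= i.
    history = torch.tril(torch.ones(dec_ids.size(1), dec_ids.size(1)))
    dec_mask = dec_mask * history.unsqueeze(0)
    # Invert so True means "masked out", matching the `< 0.5` convention above.
    return enc_dec_mask < 0.5, dec_mask < 0.5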
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
    """LM logits using word embedding weights."""
    # Parallel logits.
    input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
    # Matrix multiply.
    if bias is None:
        logits_parallel = F.linear(input_parallel, word_embeddings_weight)
    else:
        logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)
    # Gather if needed.
    if parallel_output:
        return logits_parallel
    return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
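# A single-GPU sketch of what parallel_lm_logits computes: the LM head ties its
# weights to the word-embedding matrix, so logits are hidden states projected onto
# the (vocab_size, hidden_size) embedding weight. With model parallelism disabled,
# the copy/gather calls above reduce to identities. Shapes here are made up.
import torch
import torch.nn.functional as F

def tied_lm_logits_sketch():
    hidden_size, vocab_size = 8, 16
    embedding = torch.nn.Embedding(vocab_size, hidden_size)
    hidden_states = torch.randn(2, 5, hidden_size)      # (batch, seq, hidden)
    logits = F.linear(hidden_states, embedding.weight)  # (batch, seq, vocab)
    assert logits.shape == (2, 5, vocab_size)
    return logits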
def decode(self, tokens_enc, enc_mask, num_tokens_to_generate, enc_input=None):
    # TODO: move method into a class inside MegatronTokenLevelEncoderDecoderModule (?)
    encoder_hidden_states, enc_output_mask = itemgetter("enc_output", "enc_output_mask")(
        self(
            encoder_input_ids=tokens_enc,
            decoder_input_ids=None,
            encoder_attn_mask=enc_mask,
            decoder_attn_mask=None,
            tokentype_ids=None,
            lm_labels=None,
            enc_hidden_states=None,
            enc_output_mask=None,
            output_enc_hidden_only=True,
            enc_input=enc_input,
        )
    )
    predicted_tokens_dec = (
        torch.LongTensor([self.tokenizer.bos_id] * tokens_enc.size(0)).unsqueeze(1).to(tokens_enc.device)
    )
    for _ in range(num_tokens_to_generate):
        dec_mask = predicted_tokens_dec != self.tokenizer.pad_id
        token_logits = itemgetter("token_logits")(
            self(
                encoder_input_ids=tokens_enc,
                decoder_input_ids=predicted_tokens_dec,
                encoder_attn_mask=enc_mask,
                decoder_attn_mask=dec_mask,
                tokentype_ids=None,
                lm_labels=None,
                enc_hidden_states=encoder_hidden_states,
                enc_output_mask=enc_output_mask,
                output_enc_hidden_only=False,
                enc_input=enc_input,
            )
        )
        token_logits = tensor_parallel.gather_from_tensor_model_parallel_region(token_logits)
        log_probs, token_ids = torch.max(nn.functional.log_softmax(token_logits, dim=-1), dim=-1)
        predicted_tokens_dec = torch.cat([predicted_tokens_dec, token_ids[:, -1].unsqueeze(1)], 1)
    return predicted_tokens_dec, log_probs
def decode(self, tokens_enc, enc_mask, num_tokens_to_generate, encoder_input=None):
    app_state = AppState()
    global_batch_per_gpu = tokens_enc.size(0)
    num_micro_batches_before_decode = get_num_microbatches()
    # Reconfigure the microbatch calculator to use a single microbatch while decoding,
    # since it is not clear how to decode with gradient accumulation.
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
        micro_batch_size=global_batch_per_gpu,  # Make sure that there is no "grad acc" while decoding.
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )
    predicted_tokens_dec = (
        torch.LongTensor([self.tokenizer.bos_id] * global_batch_per_gpu).unsqueeze(1).to(tokens_enc.device)
    )
    encoder_seq_length = tokens_enc.size(1)
    tensor_shape = [encoder_seq_length, global_batch_per_gpu, self.cfg.hidden_size]
    assert predicted_tokens_dec.size(0) == global_batch_per_gpu

    for i in range(num_tokens_to_generate):
        # No microbatches in decoding. Just the global batch.
        decoder_seq_length = predicted_tokens_dec.size(1)
        dec_mask = predicted_tokens_dec != self.tokenizer.pad_id

        if encoder_input is not None:
            batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask, encoder_input]
        else:
            batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask]

        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            output_tensor = forward_backward_pipelining_without_interleaving(
                forward_step_func=self.get_forward_output_only_func(),
                batch=batch_for_pipeline,
                model=self.enc_dec_model,
                forward_only=True,
                tensor_shape=tensor_shape,
                decoder_sequence_length=decoder_seq_length,
                dtype=self.autocast_dtype,
            )
        else:
            output_tensor = forward_backward_no_pipelining(
                forward_step_func=self.get_forward_output_only_func(),
                batch=batch_for_pipeline,
                model=self.enc_dec_model,
                forward_only=True,
                tensor_shape=tensor_shape,
                decoder_sequence_length=decoder_seq_length,
                dtype=self.autocast_dtype,
            )

        # Get the output tensor on the last pipeline stage; other stages receive it via broadcast below.
        if parallel_state.is_pipeline_last_stage():
            output_tensor = output_tensor[0]['logits']
            output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
            log_probs, token_ids = torch.max(torch.nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)
            predicted_tokens_dec = torch.cat(
                [predicted_tokens_dec.to(token_ids.device), token_ids[:, -1].unsqueeze(1)], dim=1
            )
        else:
            log_probs = torch.zeros(
                (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1]), dtype=self.autocast_dtype
            ).cuda()
            predicted_tokens_dec = torch.zeros(
                (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1] + 1),
                dtype=predicted_tokens_dec.dtype,
            ).cuda()

        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            # Broadcast from the last pipeline stage to all other model-parallel ranks.
            torch.distributed.broadcast(
                predicted_tokens_dec,
                parallel_state.get_pipeline_model_parallel_last_rank(),
                group=parallel_state.get_model_parallel_group(),
            )
            torch.distributed.broadcast(
                log_probs,
                parallel_state.get_pipeline_model_parallel_last_rank(),
                group=parallel_state.get_model_parallel_group(),
            )

    # Reset the microbatch calculator to what it was before decoding.
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
        micro_batch_size=global_batch_per_gpu // num_micro_batches_before_decode,
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )
    return predicted_tokens_dec, log_probs
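# The pipeline-parallel decode above is, at its core, the same greedy loop as the
# earlier decode methods: run the model on everything generated so far, take the
# argmax at the last position, append, repeat. A minimal single-process sketch,
# where `step` is a hypothetical callable returning (batch, seq, vocab) logits:
import torch
import torch.nn.functional as F

def greedy_decode_sketch(step, bos_id, eos_id, num_tokens_to_generate, batch_size=1):
    tokens = torch.full((batch_size, 1), bos_id, dtype=torch.long)
    log_probs = None
    for _ in range(num_tokens_to_generate):
        logits = step(tokens)                                   # (b, seq, vocab)
        log_probs, token_ids = torch.max(F.log_softmax(logits, dim=-1), dim=-1)
        tokens = torch.cat([tokens, token_ids[:, -1:]], dim=1)  # append argmax token
        if batch_size == 1 and token_ids[0, -1].item() == eos_id:
            break
    return tokens, log_probs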
def decode(self, enc_query, enc_taskname, label_position, num_tokens_to_generate):
    with torch.no_grad():
        predicted_tokens_dec = enc_query
        label_start = label_position[:, 0].clone()
        for _ in range(num_tokens_to_generate):
            attn_mask = make_attention_mask_3d(predicted_tokens_dec, predicted_tokens_dec, self.pad_token_id)
            attn_mask = attn_mask * make_history_mask_3d(predicted_tokens_dec)
            attn_mask = attn_mask < 0.5
            attn_mask = attn_mask.unsqueeze(1)
            input_embeds = self.embed_input(predicted_tokens_dec, enc_taskname)
            encoder_position_ids = build_position_ids(predicted_tokens_dec)
            position_embeddings = self.model.model.language_model.embedding.position_embeddings(
                encoder_position_ids
            )
            encoder_input = input_embeds + position_embeddings
            if self.float_type == torch.float32:
                output = self.model.model(
                    None, None, encoder_input=encoder_input, attention_mask=attn_mask,
                )
            else:
                with torch.autocast(device_type="cuda", dtype=self.float_type):
                    output = self.model.model(
                        None, None, encoder_input=encoder_input, attention_mask=attn_mask,
                    )
            output_tensor = output
            output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
            # TODO: add logic to use the allowed labels if they are defined.
            log_probs, token_ids = torch.max(nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)
            # Append a pad placeholder, then scatter the prediction into the label slot.
            new_pred = torch.full_like(token_ids[:, 0:1], self.pad_token_id)
            predicted_tokens_dec = torch.cat([predicted_tokens_dec, new_pred], 1)
            predicted = torch.gather(token_ids, 1, label_start.view(-1, 1))
            # Need to scatter the token id at the right position.
            label_start += 1
            predicted_tokens_dec.scatter_(1, label_start.view(-1, 1), predicted)
    return predicted_tokens_dec, log_probs
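# A small illustration of the gather/scatter trick above: read the prediction made
# at each example's current label position, then write it one slot to the right.
# All ids and shapes below are made up.
import torch

token_ids = torch.tensor([[5, 7, 9], [4, 6, 8]])                 # per-position argmax ids
tokens = torch.zeros(2, 4, dtype=torch.long)                     # running sequences
label_start = torch.tensor([1, 2])                               # current label positions
predicted = torch.gather(token_ids, 1, label_start.view(-1, 1))  # tensor([[7], [8]])
label_start += 1
tokens.scatter_(1, label_start.view(-1, 1), predicted)           # write at shifted slots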
def decode(self, tokens_enc, enc_mask, num_tokens_to_generate, encoder_input=None, tokenizer=None):
    # Check whether DDP is initialized. This is needed when running inference outside of the training loop.
    if parallel_state.is_unitialized():

        def dummy():
            return

        if self.trainer.strategy.launcher is not None:
            self.trainer.strategy.launcher.launch(dummy, trainer=self.trainer)
        self.trainer.strategy.setup_environment()

        # Reconfigure microbatch sizes here because on model restore, this will contain the
        # micro/global batch configuration used while training.
        _reconfigure_microbatch_calculator(
            rank=0,  # This doesn't matter since it is only used for logging.
            rampup_batch_size=None,
            global_batch_size=1,
            micro_batch_size=1,  # Make sure that there is no "grad acc" while decoding.
            data_parallel_size=1,  # We check above to make sure the data-parallel size is always 1 at inference.
        )

    # Classes that inherit from this class may use a different tokenizer.
    tokenizer = self.tokenizer if tokenizer is None else tokenizer
    app_state = AppState()
    global_batch_per_gpu = tokens_enc.size(0)
    num_micro_batches_before_decode = get_num_microbatches()
    # Reconfigure the microbatch calculator to use a single microbatch while decoding,
    # since it is not clear how to decode with gradient accumulation.
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
        micro_batch_size=global_batch_per_gpu,  # Make sure that there is no "grad acc" while decoding.
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )
    predicted_tokens_dec = (
        torch.LongTensor([tokenizer.bos_id] * global_batch_per_gpu).unsqueeze(1).to(tokens_enc.device)
    )
    encoder_seq_length = tokens_enc.size(1)
    tensor_shape = [encoder_seq_length, global_batch_per_gpu, self.cfg.hidden_size]
    assert predicted_tokens_dec.size(0) == global_batch_per_gpu

    for i in range(num_tokens_to_generate):
        # No microbatches in decoding. Just the global batch.
        decoder_seq_length = predicted_tokens_dec.size(1)
        dec_mask = predicted_tokens_dec != tokenizer.pad_id

        if encoder_input is not None:
            batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask, encoder_input]
        else:
            batch_for_pipeline = [tokens_enc, predicted_tokens_dec, enc_mask, dec_mask]

        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            output_tensor = forward_backward_pipelining_without_interleaving(
                forward_step_func=self.get_forward_output_only_func(),
                batch=batch_for_pipeline,
                model=self.enc_dec_model,
                forward_only=True,
                tensor_shape=tensor_shape,
                decoder_sequence_length=decoder_seq_length,
                dtype=self.autocast_dtype,
            )
        else:
            output_tensor = forward_backward_no_pipelining(
                forward_step_func=self.get_forward_output_only_func(),
                batch=batch_for_pipeline,
                model=self.enc_dec_model,
                forward_only=True,
                tensor_shape=tensor_shape,
                decoder_sequence_length=decoder_seq_length,
                dtype=self.autocast_dtype,
            )

        # Get the output tensor on the last pipeline stage; other stages receive it via broadcast below.
        if parallel_state.is_pipeline_last_stage():
            output_tensor = output_tensor[0]['logits']
            output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
            log_probs, token_ids = torch.max(torch.nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)
            predicted_tokens_dec = torch.cat(
                [predicted_tokens_dec.to(token_ids.device), token_ids[:, -1].unsqueeze(1)], dim=1
            )
        else:
            log_probs = torch.zeros(
                (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1]), dtype=self.autocast_dtype
            ).cuda()
            predicted_tokens_dec = torch.zeros(
                (predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1] + 1),
                dtype=predicted_tokens_dec.dtype,
            ).cuda()

        if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
            # Broadcast from the last pipeline stage to all other model-parallel ranks.
            torch.distributed.broadcast(
                predicted_tokens_dec,
                parallel_state.get_pipeline_model_parallel_last_rank(),
                group=parallel_state.get_model_parallel_group(),
            )
            torch.distributed.broadcast(
                log_probs,
                parallel_state.get_pipeline_model_parallel_last_rank(),
                group=parallel_state.get_model_parallel_group(),
            )

    # Reset the microbatch calculator to what it was before decoding.
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(),
        micro_batch_size=global_batch_per_gpu // num_micro_batches_before_decode,
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )
    return predicted_tokens_dec, log_probs
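# The save/reconfigure/restore pattern above recurs in several decode methods. One
# way to keep the paired _reconfigure_microbatch_calculator calls together is a
# context manager; a sketch under the assumption that the same helpers used above
# (get_num_microbatches, parallel_state, _reconfigure_microbatch_calculator) are in scope:
from contextlib import contextmanager

@contextmanager
def single_microbatch(app_state, global_batch_per_gpu):
    num_micro_batches_before = get_num_microbatches()
    dp_size = parallel_state.get_data_parallel_world_size()
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=global_batch_per_gpu * dp_size,
        micro_batch_size=global_batch_per_gpu,  # no "grad acc" inside the block
        data_parallel_size=dp_size,
    )
    try:
        yield
    finally:
        # Restore the pre-decode microbatch configuration.
        _reconfigure_microbatch_calculator(
            rank=app_state.global_rank,
            rampup_batch_size=None,
            global_batch_size=global_batch_per_gpu * dp_size,
            micro_batch_size=global_batch_per_gpu // num_micro_batches_before,
            data_parallel_size=dp_size,
        )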
def tab_sample_sequence_batch(
    model,
    context_tokens,
    context_lengths,
    attention_mask,
    position_ids,
    tokens_to_generate,
    all_probs=True,
    type_ids=None,
    temperature=None,
):
    app_state = AppState()
    micro_batch_size = context_tokens.shape[0]
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=micro_batch_size,
        micro_batch_size=micro_batch_size,
        data_parallel_size=1,
    )
    tokenizer = model.tokenizer
    sizes = tokenizer.code_column.sizes
    tokens_per_row = sum(sizes) + 1
    columns = tokenizer.code_column.columns
    num_columns = len(columns)
    tokenid_range = []
    for i in range(num_columns):
        tokenid_range.extend(tokenizer.code_column.get_range(i))

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        context = context_tokens[:, :context_length]
        # The context may start in the middle of a row; calculate the offset
        # from the position of '\n' or '<|endoftext|>'.
        positions = torch.where(context == tokenizer.eor)[1]
        if len(positions) == 0:
            positions = torch.where(context == tokenizer.eod)[1]
        if len(positions) != 0:
            max_position = positions.max().item()
            # TODO: need to make sure contexts in the same batch have the same offset lengths;
            # otherwise, the offset has to be calculated per batch_id.
            offset = (context_length - max_position - 1) % tokens_per_row
        else:
            offset = 0

        eod_id = tokenizer.eos_id
        counter = 0
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        output_logits = None

        # Generate enough tokens for the longest sequence.
        maxlen = tokens_to_generate + context_lengths.max().item()
        if maxlen > model.cfg.encoder_seq_length:
            maxlen = model.cfg.encoder_seq_length
        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length < maxlen:
            # types2use = None
            if counter == 0:
                # Allocate memory for the entire context.
                set_inference_key_value_memory = True
                tokens2use = tokens[:, :context_length]
                positions2use = position_ids[:, :context_length]
                # Not using types2use; uncomment if it is needed.
                # if type_ids is not None:
                #     types2use = type_ids[:, :context_length]
            else:
                # Set this to False so the memory is not reallocated.
                set_inference_key_value_memory = False
                tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
                positions2use = position_ids[:, context_length - 1].view(batch_size, -1)
                # Not using types2use; uncomment if it is needed.
                # if type_ids is not None:
                #     types2use = type_ids[:, context_length - 1].view(batch_size, -1)

            attention_mask_repeat = torch.concat([attention_mask for _ in range(micro_batch_size)])
            setkey_value_array = torch.tensor(
                [set_inference_key_value_memory] * micro_batch_size, device=torch.cuda.current_device()
            )
            len_array = torch.tensor([maxlen] * micro_batch_size, device=torch.cuda.current_device())
            batch = [tokens2use, attention_mask_repeat, positions2use, setkey_value_array, len_array]
            tensor_shape = [tokens2use.shape[1], micro_batch_size, model.cfg.hidden_size]

            output = forward_step(model, batch, tensor_shape)

            if parallel_state.is_pipeline_last_stage():
                output = output[0]['logits'].float()
                output = tensor_parallel.gather_from_tensor_model_parallel_region(output)
                assert output is not None
                output = output.float()
                logits = output[:, -1].view(batch_size, -1).contiguous()
                token_in_row = (counter + offset) % tokens_per_row
                logits = logits.float()
                logits /= temperature
                if token_in_row == tokens_per_row - 1:
                    # Line break: only allow end-of-row or end-of-document tokens.
                    eor_id = tokenizer.eor
                    eod_id = tokenizer.eos_id
                    min_id = min(eor_id, eod_id)
                    max_id = max(eor_id, eod_id) + 1
                    logits = tab_logits(logits, min_id, max_id)
                else:
                    # Limit sampling to the valid id range for this column.
                    min_id, max_id = tokenid_range[token_in_row]
                    logits = tab_logits(logits, min_id, max_id)
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1).view(-1)
                started = context_lengths <= context_length
                # Clamp the out-of-vocabulary tokens.
                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)
                new_tokens = switch(tokens[:, context_length].view(-1), prev, started)
                tokens[:, context_length] = new_tokens

                if output_logits is None:
                    output_context = F.log_softmax(output[:, :context_length, :], 2)
                    indices = torch.unsqueeze(tokens[:, 1:context_length + 1], 2)
                    output_logits = torch.gather(output_context, 2, indices).squeeze(2)
                    if all_probs:
                        full_logits = output_context
                else:
                    output_context = F.log_softmax(output, 2)
                    indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
                    new_output_logits = torch.gather(output_context, 2, indices).squeeze(2)
                    # TODO(rprenger): we're copying output_logits every time; should pre-allocate.
                    output_logits = torch.cat([output_logits, new_output_logits], 1)
                    if all_probs:
                        full_logits = torch.cat([full_logits, output_context], 1)

                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eod_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token
                done = torch.all(is_done)
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

                if all_probs:
                    yield tokens, lengths, output_logits, full_logits
                else:
                    yield tokens, lengths, output_logits, None
            else:
                if parallel_state.is_pipeline_first_stage():
                    src = parallel_state.get_pipeline_model_parallel_last_rank()
                    group = parallel_state.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None, None, None
                else:
                    yield None, None, None, None
                done = torch.cuda.ByteTensor([0])
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break
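# `tab_logits` is not part of this listing; a plausible sketch of the range
# masking the calls above rely on is to push every logit outside [min_id, max_id)
# to -inf, so sampling can only choose ids valid for the current column:
import torch

def tab_logits_sketch(logits, min_id, max_id, filter_value=-float('Inf')):
    logits = logits.clone()
    logits[:, :min_id] = filter_value   # block ids below the column's range
    logits[:, max_id:] = filter_value   # block ids at or above the range's end
    return logits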
def sample_sequence_batch(
    model,
    context_tokens,
    context_lengths,
    task_ids,
    attention_mask,
    position_ids,
    tokens_to_generate,
    all_probs=False,
    type_ids=None,
    temperature=None,
    extra={},
):
    # Import here to avoid circular import errors.
    from nemo.collections.nlp.models.language_modeling import MegatronGPTPromptLearningModel

    app_state = AppState()
    micro_batch_size = context_tokens.shape[0]
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=micro_batch_size,
        micro_batch_size=micro_batch_size,
        data_parallel_size=1,
    )
    tokenizer = model.tokenizer
    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        # eod_id is set to eos_id to support generate_samples_eval, which passes
        # eos_id as an argument and needs termination when that id is found.
        eod_id = tokenizer.eos_id
        counter = 0
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        output_logits = None
        all_generated_indices = None  # used to track all generated indices

        # Generate enough tokens for the longest sequence.
        maxlen = tokens_to_generate + context_lengths.max().item()
        if maxlen > model.cfg.encoder_seq_length + 1:
            maxlen = model.cfg.encoder_seq_length + 1
        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length < maxlen:
            # types2use = None
            if counter == 0:
                # Allocate memory for the entire context.
                set_inference_key_value_memory = True
                tokens2use = tokens[:, :context_length]
                positions2use = position_ids[:, :context_length]
                # Not using types2use; uncomment if it is needed.
                # if type_ids is not None:
                #     types2use = type_ids[:, :context_length]
            else:
                # Set this to False so the memory is not reallocated.
                set_inference_key_value_memory = False
                tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
                positions2use = position_ids[:, context_length - 1].view(batch_size, -1)
                # Not using types2use; uncomment if it is needed.
                # if type_ids is not None:
                #     types2use = type_ids[:, context_length - 1].view(batch_size, -1)

            attention_mask_repeat = torch.concat([attention_mask for _ in range(micro_batch_size)])
            setkey_value_array = torch.tensor(
                [set_inference_key_value_memory] * micro_batch_size, device=torch.cuda.current_device()
            )
            len_array = torch.tensor([maxlen] * micro_batch_size, device=torch.cuda.current_device())

            # Only prompt-learning models have a prompt table and require task ids.
            if isinstance(model, MegatronGPTPromptLearningModel):
                batch = [tokens2use, attention_mask_repeat, positions2use, task_ids, setkey_value_array, len_array]
                tensor_shape = [tokens2use.shape[1], micro_batch_size, model.frozen_model.cfg.hidden_size]
            else:
                batch = [tokens2use, attention_mask_repeat, positions2use, setkey_value_array, len_array]
                tensor_shape = [tokens2use.shape[1], micro_batch_size, model.cfg.hidden_size]

            output = forward_step(model, batch, tensor_shape)

            if parallel_state.is_pipeline_last_stage():
                output = output[0]['logits'].float()
                output = tensor_parallel.gather_from_tensor_model_parallel_region(output)
                assert output is not None
                output = output.float()
                logits = output[:, -1].view(batch_size, -1).contiguous()

                # Make sure at least min_length tokens are generated.
                min_length = extra.get('min_tokens_to_generate', 0)
                if min_length > 0:
                    within_min_length = (context_length - context_lengths) < min_length
                    logits[within_min_length, eod_id] = -float('Inf')

                # Make sure sampling never goes outside the vocab_size range.
                logits[:, tokenizer.vocab_size:] = -float('Inf')

                if extra.get('greedy', False):
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= temperature
                    # Handle repetition penalty.
                    logits = repetition_penalty(
                        logits, extra.get('repetition_penalty', 1.2), all_generated_indices
                    )
                    logits = top_k_logits(logits, top_k=extra.get('top_k', 0), top_p=extra.get('top_p', 0.9))
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs, num_samples=1).view(-1)
                started = context_lengths <= context_length

                # Clamp the predicted out-of-vocabulary tokens.
                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)
                new_tokens = switch(tokens[:, context_length].view(-1), prev, started)

                # Replace sampled tokens with the done token if EOD has already been sampled.
                new_tokens = switch(new_tokens, eod_id, is_done)

                # Replace special soft-prompt token ids with unk token ids.
                if isinstance(model, MegatronGPTPromptLearningModel):
                    pseudo_token_ids_start = model.pseudo_token_ids_start
                    new_tokens[(new_tokens >= pseudo_token_ids_start)] = tokenizer.unk_id
                    tokens[:, :context_length][
                        (tokens[:, :context_length] >= pseudo_token_ids_start)
                    ] = tokenizer.unk_id

                # Insert either the new prediction or the next prompt token.
                tokens[:, context_length] = new_tokens

                if output_logits is None:
                    output = F.log_softmax(output[:, :context_length, :], 2)
                    indices = torch.unsqueeze(tokens[:, 1:context_length + 1], 2)
                    output_logits = torch.gather(output, 2, indices).squeeze(2)
                    all_generated_indices = indices[:, :, 0]
                    if all_probs:
                        full_logits = output
                else:
                    output = F.log_softmax(output, 2)
                    indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
                    new_output_logits = torch.gather(output, 2, indices).squeeze(2)
                    # TODO(rprenger): we're copying output_logits every time; should pre-allocate.
                    output_logits = torch.cat([output_logits, new_output_logits], 1)
                    all_generated_indices = torch.cat([all_generated_indices, indices[:, :, 0]], 1)
                    if all_probs:
                        full_logits = torch.cat([full_logits, output], 1)

                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eod_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token
                done = torch.all(is_done)
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

                if all_probs:
                    yield tokens, lengths, output_logits, full_logits
                else:
                    yield tokens, lengths, output_logits, None
            else:
                if parallel_state.is_pipeline_first_stage():
                    src = parallel_state.get_pipeline_model_parallel_last_rank()
                    group = parallel_state.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None, None, None
                else:
                    yield None, None, None, None
                done = torch.cuda.ByteTensor([0])
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break
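# `top_k_logits` and `repetition_penalty` are likewise not shown here. A common
# sketch of combined top-k / top-p (nucleus) filtering, which the sampling branch
# above depends on: keep the k largest logits, then drop the tail of the
# probability mass beyond top_p.
import torch
import torch.nn.functional as F

def top_k_top_p_sketch(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    logits = logits.clone()
    if top_k > 0:
        # Remove everything below the k-th largest logit.
        kth = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_mask = cumulative > top_p
        # Shift right so the first token crossing the threshold is kept.
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = False
        # Map the mask back from sorted order to vocabulary order.
        mask = sorted_mask.scatter(-1, sorted_idx, sorted_mask)
        logits[mask] = filter_value
    return logits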
def complete(self, request: List, positions: List, tokens_to_generate: int):
    """
    Autoregressively invokes the language model in inference mode.

    Args:
        request:
            * tokens: List of "buckets" with unpadded tokens of the same length
            * prompt_tags: List of "buckets" where each bucket contains the prompt_tag strings
              specifying the prompt tag to use (optional)
        positions: List with initial prompt positions
        tokens_to_generate: int value denoting the number of tokens the model should generate

    Returns:
        response: A python list of tuples (text, tokens, log_probs, offsets)
            * text: string, inputted prompt + text generated by the model
            * tokens: list of tokens corresponding to text
            * log_probs: list of token log probabilities
            * offsets: list of token start positions in text
    """
    results = []
    request_tokens = request["tokens"]

    for idx, tokens in enumerate(request_tokens):
        # For prompt-tuned GPT models.
        if self.use_soft_prompts:
            prompt_tags = request["prompt_tags"][idx]
        else:
            prompt_tags = None

        logsoftmaxlayer = torch.nn.LogSoftmax(dim=-1)

        for i in range(tokens_to_generate + 1):
            if self.use_soft_prompts:
                batch_size = len(tokens)
                full_length = len(tokens[0]) + self.num_prompt_tokens
                # Get position ids for the text after the soft prompt.
                position_ids = torch.arange(
                    start=self.num_prompt_tokens, end=full_length, dtype=torch.long, device=self.device
                )
                position_ids = position_ids.unsqueeze(0).expand_as(tokens).clone()
                # Make an attention mask starting with the first token in the soft prompt.
                attention_mask = torch.tril(
                    torch.ones((batch_size, full_length, full_length), device=self.device)
                ).view(batch_size, 1, full_length, full_length)
                attention_mask = attention_mask < 0.5
            else:
                attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
                    data=tokens,
                    eod_token=self.tokenizer.eos_id,
                    reset_position_ids=self.cfg.get('reset_position_ids', False),
                    reset_attention_mask=self.cfg.get('reset_attention_mask', False),
                    eod_mask_loss=self.cfg.get('eod_mask_loss', False),
                )

            # No labels during inference. Masks are still needed so positions do not attend to the right.
            output_tensor = self(tokens, position_ids, attention_mask, prompt_tags=prompt_tags, labels=None)
            output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
            log_probs, token_ids = torch.max(logsoftmaxlayer(output_tensor), dim=-1)
            reached_eos = token_ids[0, -1].item() == self.tokenizer.eos_id
            tokens = torch.cat([tokens, torch.unsqueeze(token_ids[:, -1], 1)], dim=1)

        # Add to results as (text, tokens, log_probs, offsets).
        for token, prob in zip(tokens, log_probs.tolist()):
            results.append(
                (
                    self.tokenizer.ids_to_text(token[:-1]),
                    self.tokenizer.ids_to_tokens(token[:-1]),
                    prob,
                    [0],
                )
            )

    # Offsets calculation.
    for item in results:
        for index, token in enumerate(item[1]):
            if index != len(item[1]) - 1:
                item[3].append(len(token) + item[3][-1])

    # Return prompts in the order they were inputted.
    response = [0 for i in range(len(positions))]
    for item, index in zip(results, positions):
        response[index] = item

    return response
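# A quick illustration of the offsets loop above: each entry is the starting
# character position of a token in the detokenized text, accumulated from token
# lengths. Toy tokens only; real tokenizers emit sub-word pieces.
tokens = ["Hel", "lo", " wor", "ld"]
offsets = [0]
for index, token in enumerate(tokens):
    if index != len(tokens) - 1:
        offsets.append(len(token) + offsets[-1])
assert offsets == [0, 3, 5, 9]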
def complete(self, request: Dict):
    """
    Autoregressively invokes the language model in inference mode.

    Args:
        request: Dictionary with the following fields
            * prompt: a string with the text the model should complete.
            * tokens_to_generate: how many tokens to generate while doing prompt completion.
            * stop_after_sentence: (default True) whether to stop generation once a sentence end is reached.

    Returns:
        response: A python dictionary with the following fields
            * prompt: original text of the prompt
            * tokenized_prompt: list of (str) tokens from the prompt
            * completion: a python dictionary with the following subfields:
                * tokens: a list of triples (token, token_id, log_prob) comprising the completion
                * stop reason: either 'eos', 'sentence_end' or 'limit', indicating why generation stopped
                * text: completion text (as a single string)
    """
    response = {}
    self.freeze()
    logsoftmaxlayer = torch.nn.LogSoftmax(dim=-1)
    response['tokenized_prompt'] = request['tokenized_prompt']
    tokens = request['tokens']
    # Naive, slow greedy loop.
    # TODO: add option for BeamSearchDecoder
    response['prompt'] = request['prompt']
    response['completion'] = {}
    response['completion']['stop reason'] = 'limit'

    for i in range(request.get("tokens_to_generate", 64)):
        attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
            data=tokens,
            eod_token=self.tokenizer.eos_id,
            reset_position_ids=self.cfg.get('reset_position_ids', False),
            reset_attention_mask=self.cfg.get('reset_attention_mask', False),
            eod_mask_loss=self.cfg.get('eod_mask_loss', False),
        )
        # No labels during inference. Masks are still needed so positions do not attend to the right.
        output_tensor = self(tokens, position_ids, attention_mask, labels=None)
        output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
        log_probs, token_ids = torch.max(logsoftmaxlayer(output_tensor), dim=-1)
        reached_eos = token_ids[0, -1].item() == self.tokenizer.eos_id
        tokens = torch.cat([torch.squeeze(tokens), token_ids[:, -1]])
        response['completion']["tokens"] = list(
            zip(self.tokenizer.ids_to_tokens(tokens), tokens.tolist(), log_probs.tolist()[0])
        )
        completion_text = self.tokenizer.ids_to_text(x[1] for x in response['completion']["tokens"])
        if reached_eos:  # Will this ever actually be reached?
            response['completion']['stop reason'] = 'eos'
            break
        elif request.get("stop_after_sentence", True) and completion_text.endswith(('.', '!', '?')):
            response['completion']['stop reason'] = 'sentence_end'
            break
        tokens = torch.unsqueeze(tokens, 0)

    response['completion']["text"] = self.tokenizer.ids_to_text(x[1] for x in response['completion']["tokens"])
    self.unfreeze()
    return response
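# A hypothetical invocation of the dictionary-based complete() above, assuming a
# loaded `model` whose tokenizer exposes NeMo's text_to_ids/ids_to_tokens; the
# request fields follow the docstring.
import torch

prompt = "The capital of France is"
token_ids = model.tokenizer.text_to_ids(prompt)
request = {
    "prompt": prompt,
    "tokenized_prompt": model.tokenizer.ids_to_tokens(token_ids),
    "tokens": torch.tensor([token_ids], device=model.device),
    "tokens_to_generate": 16,
    "stop_after_sentence": True,
}
response = model.complete(request)
print(response["completion"]["text"], response["completion"]["stop reason"])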