def get_batch(data_iterator):
    """Generate a batch."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Items and their type.
    keys = ['text']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, labels, loss_mask, attention_mask, position_ids
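# A minimal, self-contained sketch of the next-token target construction used
# in get_batch(): the same sequence is shifted by one position so that tokens
# at position i are trained to predict labels at position i. The toy values
# below are illustrative only.
def _example_next_token_targets():
    import torch
    tokens_ = torch.tensor([[10, 11, 12, 13]])
    labels = tokens_[:, 1:].contiguous()   # [[11, 12, 13]]
    tokens = tokens_[:, :-1].contiguous()  # [[10, 11, 12]]
    # Position i of `tokens` predicts position i of `labels`.
    return tokens, labels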
def detokenize_generations(tokens_gpu_tensor,
                           lengths_gpu_tensor,
                           return_segments):
    """Detokenize the generated tokens."""

    tokenizer = get_tokenizer()

    prompts_plus_generations = []
    if return_segments:
        prompts_plus_generations_segments = []

    tokens = tokens_gpu_tensor.cpu().numpy().tolist()
    lengths = lengths_gpu_tensor.cpu().numpy().tolist()
    for sequence_tokens, length in zip(tokens, lengths):
        sequence_tokens = sequence_tokens[:length]
        prompts_plus_generations.append(
            tokenizer.detokenize(sequence_tokens))
        if return_segments:
            words = []
            for token in sequence_tokens:
                word = tokenizer.tokenizer.decoder[token]
                word = bytearray(
                    [tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
                        'utf-8', errors='replace')
                words.append(word)
            prompts_plus_generations_segments.append(words)

    if return_segments:
        return tokens, prompts_plus_generations, \
            prompts_plus_generations_segments

    return tokens, prompts_plus_generations
def __init__(self, num_tokentypes=2,
             parallel_output=True, pre_process=True, post_process=True):
    super(PretrainedBertModel, self).__init__()

    args = get_args()
    tokenizer = get_tokenizer()
    self.pad_id = tokenizer.pad
    self.biencoder_projection_dim = args.biencoder_projection_dim
    self.parallel_output = parallel_output
    self.pre_process = pre_process
    self.post_process = post_process
    init_method = init_method_normal(args.init_method_std)
    scaled_init_method = scaled_init_method_normal(
        args.init_method_std, args.num_layers)

    self.language_model, self._language_model_key = get_language_model(
        num_tokentypes=num_tokentypes,
        add_pooler=False,
        encoder_attn_mask_type=AttnMaskType.padding,
        init_method=init_method,
        scaled_init_method=scaled_init_method,
        pre_process=self.pre_process,
        post_process=self.post_process)

    if args.biencoder_projection_dim > 0:
        self.projection_enc = get_linear_layer(args.hidden_size,
                                               args.biencoder_projection_dim,
                                               init_method)
        self._projection_enc_key = 'projection_enc'
def generate_samples_unconditional(model):

    args = get_args()
    tokenizer = get_tokenizer()

    num_samples = args.num_samples
    context_tokens = [[tokenizer.eod] for _ in range(args.batch_size)]
    ctr = 0
    while True:
        start_time = time.time()
        for token_stream in get_token_stream(model,
                                             copy.deepcopy(context_tokens)):
            pass
        if ctr % args.log_interval == 0:
            print('Avg s/batch:',
                  (time.time() - start_time) / min(args.log_interval, ctr + 1))
            start_time = time.time()
        length = len(token_stream)
        token_batch = token_stream[0].cpu().numpy().tolist()
        length_batch = token_stream[1].cpu().numpy().tolist()
        for tokens, length in zip(token_batch, length_batch):
            tokens = tokens[1:length - 1]
            text = tokenizer.detokenize(tokens)
            is_finished = length < args.seq_length - 1
            datum = {'text': text, 'length': length - 1,
                     'finished': is_finished}
            yield datum
            ctr += 1
            if ctr >= num_samples:
                break
        if ctr >= num_samples:
            break
def __init__(self, name, block_dataset, title_dataset, data_prefix,
             num_epochs, max_num_samples, max_seq_length,
             query_in_block_prob, seed, use_titles=True,
             use_one_sent_docs=False, binary_head=False):
    self.name = name
    self.seed = seed
    self.max_seq_length = max_seq_length
    self.query_in_block_prob = query_in_block_prob
    self.block_dataset = block_dataset
    self.title_dataset = title_dataset
    self.rng = random.Random(self.seed)
    self.use_titles = use_titles
    self.use_one_sent_docs = use_one_sent_docs

    self.samples_mapping = get_block_samples_mapping(
        block_dataset, title_dataset, data_prefix, num_epochs,
        max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
    self.tokenizer = get_tokenizer()
    self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
    self.vocab_id_to_token_list = self.tokenizer.inv_vocab
    self.cls_id = self.tokenizer.cls
    self.sep_id = self.tokenizer.sep
    self.mask_id = self.tokenizer.mask
    self.pad_id = self.tokenizer.pad
def __init__(self, name, data_prefix, documents, indexed_dataset,
             num_samples, seq_length, masked_lm_prob, seed):
    super().__init__()

    self.name = name
    self.seed = seed
    self.indexed_dataset = indexed_dataset
    self.seq_length = seq_length
    self.masked_lm_prob = masked_lm_prob

    # Checks.
    assert np.min(documents) >= 0
    assert np.max(documents) < indexed_dataset.sizes.shape[0]

    # Build index mappings.
    self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings(
        self.name, data_prefix, documents, self.indexed_dataset.sizes,
        num_samples, seq_length, seed)

    # Vocab stuff.
    tokenizer = get_tokenizer()
    self.vocab_id_list = list(tokenizer.inv_vocab.keys())
    self.vocab_id_to_token_dict = tokenizer.inv_vocab
    self.cls_id = tokenizer.cls
    self.sep_id = tokenizer.sep
    self.mask_id = tokenizer.mask
    self.pad_id = tokenizer.pad
def get_token_stream(model, context_tokens):

    args = get_args()
    tokenizer = get_tokenizer()

    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)

    torch.distributed.broadcast(context_length_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    torch.distributed.broadcast(context_tokens_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())

    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids)
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        yield tokens[:, :context_length], lengths
def generate_samples_eval(model, context, max_gen_length, eos_token_id):
    # Generate samples for lm evaluation.
    # NEED TO THINK ABOUT eos token.

    args = get_args()
    tokenizer = get_tokenizer()

    raw_text_len = len(context)
    model.eval()

    context_tokens = tokenizer.tokenize(context)
    args.out_seq_length = max_gen_length + len(context_tokens)
    args.eos_id = eos_token_id

    with torch.no_grad():
        token_stream = get_token_stream(model, [context_tokens])
        for counter, decode_tokens in enumerate(token_stream):
            if counter == args.out_seq_length:
                break

    decode_tokens, _ = decode_tokens
    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
    trim_decode_tokens = tokenizer.detokenize(
        decode_tokens)[raw_text_len:]

    return trim_decode_tokens
def __init__(self, name, indexed_dataset, data_prefix,
             num_epochs, max_num_samples, masked_lm_prob,
             max_seq_length, short_seq_prob, seed):

    # Params to store.
    self.name = name
    self.seed = seed
    self.masked_lm_prob = masked_lm_prob
    self.max_seq_length = max_seq_length

    # Dataset.
    self.indexed_dataset = indexed_dataset

    # Build the samples mapping.
    self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
                                                data_prefix,
                                                num_epochs,
                                                max_num_samples,
                                                self.max_seq_length,
                                                short_seq_prob,
                                                self.seed,
                                                self.name)

    # Vocab stuff.
    tokenizer = get_tokenizer()
    self.vocab_id_list = list(tokenizer.inv_vocab.keys())
    self.vocab_id_to_token_dict = tokenizer.inv_vocab
    self.cls_id = tokenizer.cls
    self.sep_id = tokenizer.sep
    self.mask_id = tokenizer.mask
    self.pad_id = tokenizer.pad
def get_batch(data_iterator):
    """Build the batch."""

    tokenizer = get_tokenizer()

    # Items and their type.
    keys = ['text', 'labels', 'loss_mask', 'padding_mask']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens = data_b['text'].long()
    loss_mask = data_b['loss_mask'].float()
    lm_labels = data_b['labels'].long()
    padding_mask = data_b['padding_mask'].long()

    # Get the masks and position ids.
    attention_mask, position_ids = get_tape_masks_and_position_ids(
        tokens,
        tokenizer.cls,
        reset_position_ids=True,
        reset_attention_mask=True)

    return tokens, loss_mask, lm_labels, padding_mask, \
        attention_mask, position_ids
def get_batch_pipe(data):
    """A modification of get_batch() that works on a batch directly
    instead of an iterator."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Items and their type.
    keys = ['text']
    datatype = torch.int64

    # Broadcast data.
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    if args.fp16:
        # Cast to fp16 because pipeline parallelism skips the FP16 wrapper.
        return fp32_to_fp16((tokens, position_ids, attention_mask)), \
            fp32_to_fp16((labels, loss_mask))
    else:
        return (tokens, position_ids, attention_mask), (labels, loss_mask)
def get_nq_dataset(qa_data, split):
    args = get_args()
    tokenizer = get_tokenizer()

    dataset = NQDataset('Google NQ {} Split'.format(split),
                        'Google Natural Questions',
                        qa_data, tokenizer, args.retriever_seq_length)
    return dataset
def get_open_retrieval_wiki_dataset():
    args = get_args()
    tokenizer = get_tokenizer()

    dataset = OpenRetrievalEvidenceDataset('2018 Wikipedia from DPR codebase',
                                           'evidence',
                                           args.evidence_data_path,
                                           tokenizer,
                                           args.retriever_seq_length)
    return dataset
def metrics_func_provider():
    """Provide metrics callback function."""
    args = get_args()
    tokenizer = get_tokenizer()

    def single_dataset_provider(datapath):
        name = datapath.split('RACE')[-1].strip('/').replace('/', '-')
        return RaceDataset(name, [datapath], tokenizer, args.seq_length)

    return accuracy_func_provider(single_dataset_provider)
def single_dataset_provider(datapath):
    args = get_args()
    tokenizer = get_tokenizer()

    name = datapath[0].split('/')[-1].split('.')[0]
    return Dataset(name, datapath, tokenizer, args.retriever_seq_length,
                   evaluate=True)
def _build_lambada_dataset():
    """Build lambada dataset."""
    args = get_args()
    tokenizer = get_tokenizer()

    assert len(args.valid_data) == 1
    val_dataset = _LambadaDataset(args.valid_data[0], tokenizer.eod, tokenizer,
                                  args.seq_length, args.strict_lambada)
    print_rank_0(' > found {} samples.'.format(len(val_dataset)))
    return val_dataset
def train_valid_datasets_provider():
    """Build train and validation datasets."""
    args = get_args()
    tokenizer = get_tokenizer()

    train_dataset = Dataset('training', args.train_data, tokenizer,
                            args.seq_length)
    valid_dataset = Dataset('validation', args.valid_data, tokenizer,
                            args.seq_length)
    return train_dataset, valid_dataset
def get_batch(context_tokens):
    """Generate batch from context tokens."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Move to GPU.
    tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().cuda()
    # Get the attention mask and position ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, attention_mask, position_ids
def __init__(self, name, indexed_dataset, data_prefix,
             num_epochs, max_num_samples, masked_lm_prob,
             max_seq_length, max_seq_length_dec,
             short_seq_prob, seed):

    # Params to store.
    self.name = name
    self.seed = seed
    self.masked_lm_prob = masked_lm_prob
    self.max_seq_length = max_seq_length
    self.max_seq_length_dec = max_seq_length_dec

    # Dataset.
    self.indexed_dataset = indexed_dataset

    # Build the samples mapping.
    self.samples_mapping = get_samples_mapping(
        self.indexed_dataset,
        data_prefix,
        num_epochs,
        max_num_samples,
        self.max_seq_length - 2,  # account for added tokens
        short_seq_prob,
        self.seed,
        self.name,
        False)

    # Vocab stuff.
    tokenizer = get_tokenizer()
    self.vocab_id_list = list(tokenizer.inv_vocab.keys())
    self.vocab_id_to_token_dict = tokenizer.inv_vocab
    self.cls_id = tokenizer.cls
    self.sep_id = tokenizer.sep
    self.mask_id = tokenizer.mask
    self.pad_id = tokenizer.pad
    self.bos_id = tokenizer.bos_token_id
    self.eos_id = tokenizer.eos_token_id
    self.sentinel_tokens = tokenizer.additional_special_tokens_ids
    assert len(self.sentinel_tokens) > 0, \
        "Provide the argument --vocab-extra-ids 100 to the script"
def train_valid_datasets_provider():
    """Build train and validation datasets."""
    args = get_args()
    tokenizer = get_tokenizer()

    train_dataset = Dataset('training', args.train_data, tokenizer,
                            args.retriever_seq_length, evaluate=False)
    valid_dataset = Dataset('validation', args.valid_data, tokenizer,
                            args.retriever_seq_length, evaluate=True)
    return train_dataset, valid_dataset
def process_batch(batch):
    """Process batch and produce inputs for the model."""
    args = get_args()
    tokenizer = get_tokenizer()

    loss_mask = batch['pad_mask'].long().cuda().contiguous().byte()
    tokens_ = batch['text'].long().cuda().contiguous()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, labels, attention_mask, position_ids, loss_mask
def cross_entropy_forward_step(batch, model):
    """Simple forward step with cross-entropy loss."""
    timers = get_timers()
    tokenizer = get_tokenizer()

    # Get the batch.
    timers('batch generator').start()
    try:
        batch_ = next(batch)
    except BaseException:
        batch_ = batch

    group, rank, world_size = get_group_world_size_rank()

    query_tokens, query_mask, query_types, query_pad_mask, \
        context_tokens, context_mask, context_types, context_pad_mask, \
        neg_context_tokens, neg_context_mask, neg_context_types, \
        reference = process_batch(batch_)
    timers('batch generator').stop()

    local_batch_size = query_tokens.shape[0]

    # Text representation of query and context.
    query_list, context_list = [], []
    for i in range(local_batch_size):
        query_list.append(tokenizer.decode(query_tokens[i].tolist()))
        context_list.append(tokenizer.decode(context_tokens[i].tolist()))

    if neg_context_tokens is not None:
        neg_context_tokens = check_and_append_tensor_for_gather(
            group, rank, world_size, neg_context_tokens)
        neg_context_mask = check_and_append_tensor_for_gather(
            group, rank, world_size, neg_context_mask)
        neg_context_types = check_and_append_tensor_for_gather(
            group, rank, world_size, neg_context_types)

    if neg_context_tokens is not None:
        context_tokens = torch.cat([context_tokens, neg_context_tokens])
        context_mask = torch.cat([context_mask, neg_context_mask])
        context_types = torch.cat([context_types, neg_context_types])

    # Forward model.
    output_tensor = model(query_tokens, query_mask, query_types,
                          context_tokens, context_mask, context_types)

    return output_tensor, partial(cross_entropy_loss_func,
                                  query_tokens, context_tokens)
def _build_wikitext103_dataset():
    """Build wikitext-103 dataset."""
    args = get_args()
    tokenizer = get_tokenizer()

    assert len(args.valid_data) == 1
    with open(args.valid_data[0], "rb") as reader:
        entire_data = reader.read().decode('utf-8')
    num_original_tokens = len(entire_data.strip().split(" "))
    entire_data = get_detokenizer(args.valid_data[0])(entire_data)
    tokenized_data = tokenizer.tokenize(entire_data)
    num_tokenized_tokens = len(tokenized_data)

    val_dataset = _LMDataset(tokenized_data, args.seq_length, tokenizer.eod,
                             num_original_tokens, num_tokenized_tokens,
                             args.overlapping_eval)
    print_rank_0(' > number of original tokens: {}, number of detokenized '
                 'tokens: {}'.format(num_original_tokens,
                                     num_tokenized_tokens))
    return val_dataset
def get_token_stream(model, context_tokens):

    args = get_args()
    tokenizer = get_tokenizer()

    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)

    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids)
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        if tokens is not None:
            yield tokens[:, :context_length], lengths
        else:
            yield None, None
def _tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS):
    """Given a set of prompts and the number of tokens to generate:
        - tokenize the prompts,
        - set the sequence length to the maximum prompt length plus the
          number of tokens we would like to generate,
        - pad all the sequences to this length so we can convert them
          into a 2D tensor.
    """

    # Tokenize all the prompts.
    tokenizer = get_tokenizer()
    if add_BOS:
        prompts_tokens = [[tokenizer.eod] + tokenizer.tokenize(prompt)
                          for prompt in prompts]
    else:
        prompts_tokens = [tokenizer.tokenize(prompt) for prompt in prompts]

    # Now we have a list of lists of tokens in which each list has a
    # different size. We want to extend these lists to:
    #   - incorporate the tokens that need to be generated,
    #   - make all the sequences equal length.
    # Get the prompt lengths.
    prompts_length = [len(prompt_tokens) for prompt_tokens in prompts_tokens]
    # Get the max prompt length.
    max_prompt_len = max(prompts_length)
    # Number of tokens in each sample of the batch.
    samples_length = max_prompt_len + tokens_to_generate
    # Now pad every list to the same size: samples_length.
    for prompt_tokens, prompt_length in zip(prompts_tokens, prompts_length):
        padding_size = samples_length - prompt_length
        prompt_tokens.extend([tokenizer.eod] * padding_size)

    # Now that we are in a structured format, we can convert to tensors.
    prompts_tokens_tensor = torch.cuda.LongTensor(prompts_tokens)
    prompts_length_tensor = torch.cuda.LongTensor(prompts_length)

    return prompts_tokens_tensor, prompts_length_tensor
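# A minimal sketch of the padding arithmetic in _tokenize_prompts_and_batch(),
# with a hypothetical `pad_id` standing in for tokenizer.eod: every prompt is
# extended to max(prompt lengths) + tokens_to_generate so the batch can be
# stacked into a 2D tensor.
def _example_pad_prompts(prompts_tokens, tokens_to_generate, pad_id=0):
    prompts_length = [len(p) for p in prompts_tokens]
    samples_length = max(prompts_length) + tokens_to_generate
    padded = [p + [pad_id] * (samples_length - len(p)) for p in prompts_tokens]
    # e.g. _example_pad_prompts([[1, 2], [3, 4, 5]], 3) pads both rows to length 6.
    return padded, prompts_length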
def generate_samples_interactive(model, print_frequency=24):

    args = get_args()
    tokenizer = get_tokenizer()

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            token_stream = get_token_stream(model, [context_tokens])
            for counter, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

                if mpu.get_model_parallel_rank() == 0 and \
                        counter % print_frequency == 0:
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[len(raw_text):]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[len(raw_text):]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            raw_text = None
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1
            if mpu.get_model_parallel_rank() == 0:
                input("\nPress any key to continue >>>")
def single_dataset_provider(datapath):
    args = get_args()
    tokenizer = get_tokenizer()

    name = name_from_datapath_func(datapath)
    return Dataset(name, [datapath], tokenizer, args.seq_length)
def generate_samples_input_from_file(model):

    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.get_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('could not find `sample-output-file`, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file
        fname_out = open(sample_output_file, "w+")

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                if input_pos == input_count:
                    raw_text = "stop"

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[len(raw_text):]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                fname_out.write("\nContext:")
                fname_out.write(raw_text)
                fname_out.write("\n\nMegatron-LM:")
                fname_out.write(trim_decode_tokens)
                fname_out.write("\n")

            raw_text = None
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1
def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):

    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        eos_id = tokenizer.eod

        counter = 0
        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = args.seq_length - 1
            if maxlen > (org_context_length + args.out_seq_length):
                maxlen = org_context_length + args.out_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length <= (maxlen):

            if args.recompute:
                logits = model(tokens,
                               position_ids,
                               attention_mask,
                               tokentype_ids=type_ids,
                               forward_method_parallel_output=False)
                logits = logits[:, context_length - 1, :]
            else:
                types2use = None
                if counter == 0:
                    tokens2use = tokens[:, :context_length]
                    positions2use = position_ids[:, :context_length]
                    if type_ids is not None:
                        types2use = type_ids[:, :context_length]
                else:
                    tokens2use = tokens[:, context_length - 1].view(
                        batch_size, -1)
                    positions2use = position_ids[:, context_length - 1].view(
                        batch_size, -1)
                    if type_ids is not None:
                        types2use = type_ids[:, context_length - 1].view(
                            batch_size, -1)
                logits, layer_past = model(tokens2use,
                                           positions2use,
                                           attention_mask,
                                           layer_past=layer_past,
                                           get_key_value=True,
                                           tokentype_ids=types2use,
                                           forward_method_parallel_output=False)
                logits = logits[:, -1].view(batch_size, -1).contiguous()

            if args.greedy:
                prev = torch.argmax(logits, dim=-1).view(-1)
            else:
                logits = logits.float()
                logits /= args.temperature
                logits = top_k_logits(logits, top_k=args.top_k,
                                      top_p=args.top_p)
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1).view(-1)

            print_logits = []
            for p in prev:
                print_logits.append([logits[i, p].item()
                                     for i in range(batch_size)])
            started = context_lengths <= context_length
            tokens[:, context_length] = switch(
                tokens[:, context_length].view(-1), prev, started)
            context_length += 1
            counter += 1

            done_token = (prev == eos_id).byte() & started.byte()
            just_finished = (done_token & ~is_done).bool()
            lengths[just_finished.view(-1)] = context_length
            is_done = is_done | done_token
            done = torch.all(is_done)

            yield tokens, lengths
            if done:
                break
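# A self-contained sketch of top-k/top-p (nucleus) filtering in the spirit of
# the top_k_logits() call above. This is an illustrative stand-in, not the
# repo's implementation: logits below the k-th largest value are masked out,
# and tokens outside the smallest set whose cumulative probability exceeds
# top_p are masked out.
def _example_top_k_top_p_filter(logits, top_k=0, top_p=0.0,
                                filter_value=-float('inf')):
    import torch
    if top_k > 0:
        # Mask every logit below the k-th largest value in each row.
        kth_values = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth_values, filter_value)
    if top_p > 0.0:
        # Sort descending and compute cumulative probabilities.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True,
                                                   dim=-1)
        cumulative_probs = torch.cumsum(
            torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_mask = cumulative_probs > top_p
        # Shift right so the first token above the threshold is still kept.
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = False
        # Scatter the mask back to the original vocabulary order.
        mask = sorted_mask.scatter(-1, sorted_indices, sorted_mask)
        logits = logits.masked_fill(mask, filter_value)
    return logits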
def generate_samples_input_from_file(model):

    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.is_pipeline_first_stage() \
       and mpu.get_tensor_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file
        fname_out = open(sample_output_file, "w+")

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                if input_pos == input_count:
                    raw_text = "stop"
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(input_info_tensor,
                                         group=mpu.get_model_parallel_group())
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallelism we send context tokens to the other
            # stages so they get the lengths correct.
            if mpu.get_tensor_model_parallel_rank() == 0 \
               and args.pipeline_model_parallel_size > 1:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.cuda.LongTensor(
                        context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor,
                                                src, group)
                else:
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.empty(
                        context_length,
                        dtype=torch.int64,
                        device=torch.device("cuda"))
                    torch.distributed.broadcast(context_tokens_tensor,
                                                src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass

            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)

                    fname_out.write("\nContext:")
                    fname_out.write(raw_text)

                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[raw_text_len:]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                    fname_out.write("\n\nMegatron-LM:")
                    fname_out.write(trim_decode_tokens)
                    fname_out.write("\n")

            raw_text = None
            context_count += 1