def generate_samples(model, tokenizer, args):
    # Debug-instrumented variant of generate_samples with a hard-coded prompt.
    print(f"generate_samples was called with model {model} \n and tokenizer {tokenizer}")
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0
            print(f"terminate_runs = {terminate_runs}")

            if mpu.get_model_parallel_rank() == 0:
                print("get_model_parallel_rank() was 0")
                # raw_text = input("\nContext prompt (stop to exit) >>> ")
                raw_text = "localStorage.getItem("
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer(raw_text)['input_ids']
                    context_length = len(context_tokens)

                    if context_length >= args.seq_length // 2:
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the sequence length)!")
                        continue
            else:
                print(f"get_model_parallel_rank() was NOT 0 but {mpu.get_model_parallel_rank()}")
                _ = tokenizer("EMPTY TEXT")['input_ids']
                # Placeholder prompt so non-source ranks can call generate();
                # the real tokens arrive via the broadcast inside generate().
                raw_text = "EMPTY TEXT"

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            start_time = time.time()
            print("generating...")
            generated = generate(
                model, tokenizer, raw_text,
                out_seq_length=args.out_seq_length,
                seq_length=args.seq_length,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p
            )

            if mpu.get_model_parallel_rank() == 0:
                print("We should clear the terminal and print results...")
                os.system('clear')
                print("\nTime taken: {:.2f}\n".format(time.time() - start_time), flush=True)
                print("\nContext:", raw_text, flush=True)
                print("\nGPT:", generated, flush=True)

            raw_text = None
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
def generate_samples(model, tokenizer, args):
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer(raw_text)['input_ids']
                    context_length = len(context_tokens)

                    if context_length >= args.seq_length // 2:
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the sequence length)!")
                        continue
            else:
                _ = tokenizer("EMPTY TEXT")['input_ids']
                # Placeholder prompt so non-source ranks can call generate();
                # the real tokens arrive via the broadcast inside generate().
                raw_text = "EMPTY TEXT"

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            start_time = time.time()
            generated = generate(
                model, tokenizer, raw_text,
                out_seq_length=args.out_seq_length,
                seq_length=args.seq_length,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p
            )

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nTime taken: {:.2f}\n".format(time.time() - start_time), flush=True)
                print("\nContext:", raw_text, flush=True)
                print("\nGPT:", generated, flush=True)

            raw_text = None
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
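# The loop above hinges on a small control-flow handshake: rank 0 of the model
# parallel group decides whether to stop, and every other rank must learn that
# decision before the collective calls inside generate(). A minimal sketch of
# that pattern in isolation is below; it assumes torch.distributed is already
# initialized, and the helper name is mine, not part of the codebase.
import torch
import torch.distributed as dist


def broadcast_flag(flag_value, src_rank, group=None):
    # Every rank allocates the same one-element tensor; only src_rank's value
    # survives the broadcast, so all ranks return the same integer.
    flag = torch.cuda.LongTensor([flag_value])
    dist.broadcast(flag, src_rank, group=group)
    return flag[0].item()


# In generate_samples this corresponds to broadcasting terminate_runs from
# mpu.get_model_parallel_src_rank() over mpu.get_model_parallel_group().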
def get_train_val_test_data(args):
    """Load the data on rank zero and broadcast the number of tokens to all GPUs."""
    (train_data, val_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        (train_data, val_data, test_data), num_tokens, eod_token, tokenizer = make_gpt3_dataloaders(args)
        # Pad the vocabulary with dummy tokens until its size is divisible by
        # make_vocab_size_divisible_by * model_parallel_world_size.
        before = num_tokens
        after = before
        multiple = args.make_vocab_size_divisible_by * mpu.get_model_parallel_world_size()
        while (after % multiple) != 0:
            after += 1
        print_rank_0('> padded vocab (size: {}) with {} dummy tokens (new size: {})'.format(
            before, after - before, after))
        print_rank_0('> end-of-document token: {}'.format(eod_token))
        token_counts = torch.cuda.LongTensor(
            [after, eod_token, int(args.do_train), int(args.do_valid), int(args.do_test)])
    else:
        tokenizer = None
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts, mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    num_tokens = token_counts[0].item()
    eod_token = token_counts[1].item()
    args.do_train = token_counts[2].item()
    args.do_valid = token_counts[3].item()
    args.do_test = token_counts[4].item()

    return train_data, val_data, test_data, num_tokens, eod_token, tokenizer
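# The dummy-token padding above increments `after` until it hits the next
# multiple of make_vocab_size_divisible_by * model_parallel_world_size. The
# same rounding can be written in closed form; this helper is a sketch (the
# name is mine), not something the codebase defines.
def pad_vocab_size(num_tokens, make_vocab_size_divisible_by, model_parallel_world_size):
    # Round the vocabulary size up to the nearest usable multiple so the
    # embedding table splits evenly across model parallel ranks.
    multiple = make_vocab_size_divisible_by * model_parallel_world_size
    return ((num_tokens + multiple - 1) // multiple) * multiple


# Example: a GPT-2-style vocabulary of 50257 tokens, divisible-by 128 and a
# model parallel world size of 2, pads to 50432.
assert pad_vocab_size(50257, 128, 2) == 50432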
def generate(model, tokenizer, raw_text, out_seq_length=256, seq_length=512,
             temperature=1.0, top_k=0, top_p=0.9):
    context_tokens = tokenizer(raw_text)['input_ids']
    context_length = len(context_tokens)
    pad_id = tokenizer.encoder['<pad>']
    if context_length < seq_length:
        context_tokens.extend([pad_id] * (seq_length - context_length))

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor([context_length])
    # Rank 0 holds the real prompt; broadcast it so every model parallel rank
    # runs the same forward passes.
    torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())

    context_length = context_length_tensor[0].item()
    tokens = context_tokens_tensor
    tokens = tokens.view(1, -1).contiguous()
    tokens = tokens.to(torch.cuda.current_device())

    attention_mask, loss_mask, position_ids = get_masks_and_position_ids(tokens, pad_id, False, False)

    counter = 0
    start_context_length = context_length
    while counter < (start_context_length + out_seq_length):
        logits = model(tokens, position_ids, attention_mask)
        logits = logits[:, context_length - 1, :] / temperature
        logits = top_k_logits(logits, top_k=top_k, top_p=top_p)
        log_probs = torch.nn.functional.softmax(logits, dim=-1)
        prev = torch.multinomial(log_probs, num_samples=1)
        tokens[0, context_length] = prev[0]
        context_length += 1
        if context_length >= seq_length:
            break
        counter += 1

        # Stop as soon as the model has emitted the end-of-text token.
        output_tokens_list = tokens.view(-1).tolist()
        decode_tokens = tokenizer.decode(output_tokens_list)
        token_end = decode_tokens.find("<|endoftext|>")
        if token_end != -1:
            break

    output_tokens_list = tokens.view(-1).tolist()
    decode_tokens = tokenizer.decode(output_tokens_list)
    token_end = decode_tokens.find("<|endoftext|>")
    return decode_tokens[:token_end] if token_end != -1 else decode_tokens
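# generate() relies on top_k_logits to restrict sampling, but that helper is
# not shown here. The sketch below follows the top-k / nucleus (top-p)
# filtering routine commonly used in GPT sampling code; the actual
# implementation in this repository may differ in details.
import torch
import torch.nn.functional as F


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('inf')):
    if top_k > 0:
        # Keep only the top_k largest logits per row.
        values, _ = torch.topk(logits, top_k)
        min_values = values[:, -1].unsqueeze(-1)
        logits = torch.where(logits < min_values,
                             torch.full_like(logits, filter_value),
                             logits)
    if top_p > 0.0:
        # Nucleus filtering: drop tokens once the cumulative probability of the
        # sorted distribution exceeds top_p, always keeping the most likely token.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices,
                                                             sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, filter_value)
    return logits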
def has_overflow(self, params):
    overflow = self.has_overflow_serial(params)
    # Since each model parallel GPU carries only part of the model,
    # make sure overflow flag is synced across all the model parallel GPUs
    overflow_gpu = torch.cuda.ByteTensor([overflow])
    torch.distributed.all_reduce(overflow_gpu,
                                 op=torch.distributed.ReduceOp.MAX,
                                 group=mpu.get_model_parallel_group())
    overflow = overflow_gpu[0].item()
    return bool(overflow)
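# The MAX all-reduce above is a generic way to turn "did any rank see this
# condition?" into a flag every rank agrees on. A minimal, hypothetical helper
# showing the same trick outside the optimizer (assumes an initialized process
# group):
import torch
import torch.distributed as dist


def any_rank_true(local_flag, group=None):
    # MAX over {0, 1} acts as a logical OR across ranks, so every rank gets
    # True if at least one rank passed True.
    flag = torch.cuda.ByteTensor([int(local_flag)])
    dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=group)
    return bool(flag[0].item())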
def __call__(self, text=None, input_ids=None, labels=None, **kwargs):
    if input_ids is None:
        if text is None:
            text = ""
        input_ids = torch.cuda.LongTensor([self.tokenizer(text)['input_ids']])
    if isinstance(input_ids, list):
        input_ids = torch.cuda.LongTensor(input_ids)
    if isinstance(labels, list):
        labels = torch.cuda.LongTensor(labels)

    res = []
    if labels is not None:
        lbls = labels
    else:
        lbls = [None] * len(input_ids)
    loss = None
    original_context_length = 0
    seq_len = self.seq_len
    for tokens, lbl in zip(input_ids, lbls):
        context_tokens = tokens.tolist()
        context_length = len(context_tokens)
        original_context_length = len(context_tokens)
        # Grow the working sequence length in steps of 16 until the context fits.
        while context_length > seq_len:
            seq_len += 16
        if context_length < seq_len:
            context_tokens.extend([self.pad_token_id] * (seq_len - context_length))
            if labels is not None:
                lbl = lbl.tolist()
                lbl.extend([self.pad_token_id] * (seq_len - context_length))
                lbl = torch.cuda.LongTensor(lbl)
        # Truncate anything beyond the 2048-token positional limit.
        if context_length > 2048:
            context_tokens = context_tokens[-2048:]
            if labels is not None:
                lbl = lbl.tolist()[-2048:]
                lbl = torch.cuda.LongTensor(lbl)

        context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
        context_length_tensor = torch.cuda.LongTensor([context_length])
        torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(),
                                    group=mpu.get_model_parallel_group())
        torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(),
                                    group=mpu.get_model_parallel_group())
        # context_length = context_length_tensor[0].item()

        tokens = context_tokens_tensor
        tokens = tokens.view(1, -1).contiguous()
        tokens = tokens.to(torch.cuda.current_device())

        attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
            tokens, self.pad_token_id, False, False)
        lm_logits = self.model(tokens, position_ids, attention_mask)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lbl[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(ignore_index=self.pad_token_id)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        res.append((lm_logits, loss))

    logits = torch.cat([x[0] for x in res], dim=0)[:, :original_context_length, :]
    if loss is not None:
        loss = [x[1] for x in res]
    return ModelOutput(logits, loss)
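# A hedged usage sketch for the wrapper's __call__ above. The `wrapper` object
# and the prompt strings are assumptions for illustration; the return value is
# the ModelOutput(logits, loss) constructed at the end of __call__.
prompt_ids = wrapper.tokenizer("localStorage.getItem(")['input_ids']

# Text-only path: input_ids are built from `text`, labels stay None, so the
# output carries logits and no loss.
out_no_loss = wrapper(text="localStorage.getItem(")

# Scoring path: passing input_ids together with labels makes __call__ compute
# a shifted cross-entropy loss per sample and return the losses as a list.
out_with_loss = wrapper(input_ids=[prompt_ids], labels=[prompt_ids])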