def get_token_stream(model, context_tokens, tokenizer, args):
    """Pad a batch of contexts, broadcast them to the model-parallel group,
    and yield progressively longer token sequences."""
    context_tokens, context_lengths = pad_batch(context_tokens, tokenizer,
                                                args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)

    torch.distributed.broadcast(context_length_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    torch.distributed.broadcast(context_tokens_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())

    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor,
                                                     args)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids,
                                                 tokenizer, args)
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        yield tokens[:, :context_length], lengths
def read_context(tokenizer, args, output):
    terminate_runs = 0
    if mpu.get_model_parallel_rank() == 0:
        while True:
            raw_text = input("\nContext prompt (stop to exit) >>> ")
            if not raw_text:
                print('Prompt should not be empty!')
                continue
            if raw_text == "stop":
                terminate_runs = 1
                break
            generation_mask = '[gMASK]' if args.task_mask else '[MASK]'
            if args.block_lm and 'MASK]' not in raw_text:
                raw_text += ' ' + generation_mask
            output.write(raw_text)
            context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
            if args.block_lm:
                context_tokens = [tokenizer.get_command('ENC').Id
                                  ] + context_tokens
                if not raw_text.endswith('MASK]'):
                    context_tokens = context_tokens + [
                        tokenizer.get_command('eos').Id
                    ]
            context_length = len(context_tokens)
            if context_length >= args.seq_length:
                print("\nContext length", context_length,
                      "\nPlease give smaller context than the window length!")
                continue
            break
    else:
        context_length = 0

    terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
    torch.distributed.broadcast(terminate_runs_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    terminate_runs = terminate_runs_tensor[0].item()
    if terminate_runs == 1:
        return terminate_runs, None, None, None

    context_length_tensor = torch.cuda.LongTensor([context_length])
    torch.distributed.broadcast(context_length_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    context_length = context_length_tensor[0].item()
    if mpu.get_model_parallel_rank() == 0:
        context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    else:
        context_tokens_tensor = torch.cuda.LongTensor([0] * context_length)
    torch.distributed.broadcast(context_tokens_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    if mpu.get_model_parallel_rank() != 0:
        raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist())
    return terminate_runs, raw_text, context_tokens_tensor, context_length
def get_train_val_test_data(args, tokenizer):
    """Load the data on rank zero and broadcast the do-train/valid/test
    flags to all GPUs."""
    (train_data, val_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        data_config = configure_data()
        if args.block_lm:
            data_set_type = "Block"
        elif args.transformer_xl:
            data_set_type = "GPT-XL"
        else:
            data_set_type = "GPT2"
        data_config.set_defaults(data_set_type=data_set_type, transpose=False)
        train_data, val_data, test_data = data_config.apply(args, tokenizer)
        data_counts = torch.cuda.LongTensor(
            [int(args.do_train), int(args.do_valid), int(args.do_test)])
    else:
        data_counts = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the flags.
    torch.distributed.broadcast(data_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.do_train = data_counts[0].item()
    args.do_valid = data_counts[1].item()
    args.do_test = data_counts[2].item()

    return train_data, val_data, test_data
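# The functions in this file repeat the same idiom: rank 0 of each
# model-parallel group computes some integers, packs them into a CUDA
# LongTensor, and broadcasts them from the group's source rank so every
# rank sees the same values. A minimal, self-contained sketch of that
# idiom as a reusable helper follows; `broadcast_scalars` is a
# hypothetical name, not part of this codebase.
def broadcast_scalars(values, n):
    """Broadcast `n` integers from the model-parallel source rank.

    On rank 0 of the group, `values` holds the real numbers; other ranks
    may pass anything (e.g. None) and receive the broadcast result.
    """
    if mpu.get_model_parallel_rank() == 0:
        tensor = torch.cuda.LongTensor(values)
    else:
        tensor = torch.cuda.LongTensor([0] * n)
    torch.distributed.broadcast(tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    return [tensor[i].item() for i in range(n)]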
def prepare_tokenizer(args):
    add_sentinel_token = 0
    if args.sentinel_token:
        add_sentinel_token = args.max_position_embeddings
    tokenizer = make_tokenizer(args.tokenizer_type, None, args.tokenizer_path,
                               args.vocab_size, args.tokenizer_model_type,
                               add_block_symbols=args.block_lm,
                               cache_dir=args.cache_dir,
                               add_sentinel_token=add_sentinel_token,
                               add_task_mask=args.task_mask,
                               add_decoder_mask=args.block_mask_prob > 0.0
                               or args.context_mask_ratio > 0.0,
                               fix_command_token=args.fix_command_token)
    if mpu.get_model_parallel_rank() == 0:
        num_tokens = tokenizer.num_tokens
        eod_token = tokenizer.get_command('eos').Id
        assert eod_token == tokenizer.get_command('pad').Id
        before = num_tokens
        after = before
        multiple = args.make_vocab_size_divisible_by
        while (after % multiple) != 0:
            after += 1
        print_rank_0('> padded vocab (size: {}) with {} dummy '
                     'tokens (new size: {})'.format(before, after - before,
                                                    after))
        print_rank_0('> found end-of-document token: {}'.format(eod_token))
        token_counts = torch.cuda.LongTensor([after, eod_token])
    else:
        token_counts = torch.cuda.LongTensor([0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    num_tokens = token_counts[0].item()
    eod_token = token_counts[1].item()
    args.vocab_size, args.eod_token = num_tokens, eod_token
    return tokenizer
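# The padding loop in prepare_tokenizer rounds the vocabulary size up to
# the next multiple of args.make_vocab_size_divisible_by. The same value
# can be computed in closed form; a sketch (pure arithmetic, no project
# code, the helper name is hypothetical):
def pad_vocab_size(before, multiple):
    """Round `before` up to the nearest multiple of `multiple`."""
    return ((before + multiple - 1) // multiple) * multiple

# e.g. the GPT-2 vocabulary of 50257 padded to a multiple of 128:
assert pad_vocab_size(50257, 128) == 50304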
def read_context(tokenizer, args, output):
    terminate_runs = 0
    if mpu.get_model_parallel_rank() == 0:
        while True:
            raw_text = input("\nContext prompt (stop to exit) >>> ")
            if not raw_text:
                print('Prompt should not be empty!')
                continue
            if raw_text == "stop":
                terminate_runs = 1
                break
            if args.hierarchical:
                raw_text = "Summary: " + raw_text
            output.write(raw_text)
            context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
            context_length = len(context_tokens)
            if context_length >= args.seq_length:
                print("\nContext length", context_length,
                      "\nPlease give smaller context than the window length!")
                continue
            break
    else:
        # raw_text is only read on rank 0; keep it defined everywhere.
        raw_text = None
        context_length = 0

    terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
    torch.distributed.broadcast(terminate_runs_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    terminate_runs = terminate_runs_tensor[0].item()
    if terminate_runs == 1:
        return terminate_runs, raw_text, None, None

    context_length_tensor = torch.cuda.LongTensor([context_length])
    torch.distributed.broadcast(context_length_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    context_length = context_length_tensor[0].item()
    if mpu.get_model_parallel_rank() == 0:
        context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    else:
        context_tokens_tensor = torch.cuda.LongTensor([0] * context_length)
    torch.distributed.broadcast(context_tokens_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    return terminate_runs, raw_text, context_tokens_tensor, context_length
def backward(ctx, grad_x):
    rank = torch.distributed.get_rank()
    prev_dst_rank = mpu.get_model_parallel_prev_dst_rank()
    src_rank = mpu.get_model_parallel_src_rank()
    pipeline_group = mpu.get_pipeline_parallel_succ_group()
    if prev_dst_rank is not None and rank == src_rank:
        assert pipeline_group is not None
        # Send the gradient back to the previous pipeline stage.
        dist.broadcast(grad_x, src_rank, group=pipeline_group)
    return None
def evaluate(data_loader, model, args, timers, num_iterations=None):
    """Evaluation."""
    # Turn on evaluation mode which disables dropout.
    model.eval()

    total_lm_loss = 0
    if num_iterations is not None:
        max_iters = num_iterations
    else:
        if mpu.get_model_parallel_rank() == 0:
            max_iters_gpu = torch.cuda.LongTensor([len(data_loader)])
        else:
            max_iters_gpu = torch.cuda.LongTensor([0])
        torch.distributed.broadcast(max_iters_gpu,
                                    mpu.get_model_parallel_src_rank(),
                                    group=mpu.get_model_parallel_group())
        max_iters = max_iters_gpu[0].item()
        print_rank_0('global rank: {} | max iters: {}'.format(
            torch.distributed.get_rank(), max_iters))

    if data_loader is not None:
        data_iterator = iter(data_loader)
    else:
        data_iterator = None

    with torch.no_grad():
        iteration = 0
        while iteration < max_iters:
            if iteration % args.log_interval == 0:
                print_rank_0('global rank: {} | iteration: {}'.format(
                    torch.distributed.get_rank(), iteration))
            # Forward evaluation.
            lm_loss = forward_step(data_iterator, model, args, timers)
            if lm_loss is None:
                break
            # Reduce across processes.
            if isinstance(model, DDP):
                torch.distributed.all_reduce(lm_loss.data)
                if args.cloze_eval:
                    lm_loss.data = lm_loss.data / args.world_size
                else:
                    lm_loss.data = lm_loss.data / args.model_parallel_size
            if not args.cloze_eval:
                total_lm_loss += lm_loss.data.detach().float().item() / (
                    args.num_tokenized_tokens - 1)
            else:
                total_lm_loss += lm_loss.data.detach().float().item()
            iteration += 1

    # Move model back to the train mode.
    model.train()
    return total_lm_loss
def forward(ctx, x):
    rank = torch.distributed.get_rank()
    prev_dst_rank = mpu.get_model_parallel_prev_dst_rank()
    src_rank = mpu.get_model_parallel_src_rank()
    pipeline_group = mpu.get_pipeline_parallel_succ_group()
    model_parallel_group = mpu.get_model_parallel_group()
    if prev_dst_rank is not None:
        if rank == src_rank:
            assert pipeline_group is not None
            # Receive activations from the previous pipeline stage.
            dist.broadcast(x, prev_dst_rank, group=pipeline_group)
        # Fan the activations out within the model-parallel group.
        dist.broadcast(x, src_rank, group=model_parallel_group)
    return x
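# The forward/backward pair above reads like the two halves of a
# torch.autograd.Function that receives activations from the previous
# pipeline stage and returns gradients to it. A hedged sketch of how such
# a pair is typically assembled (the class name is hypothetical, the
# bodies mirror the two functions above; backward here returns grad_x
# rather than None so autograd can keep propagating):
class PipelineBroadcast(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x):
        rank = torch.distributed.get_rank()
        prev_dst_rank = mpu.get_model_parallel_prev_dst_rank()
        src_rank = mpu.get_model_parallel_src_rank()
        if prev_dst_rank is not None:
            if rank == src_rank:
                # Receive activations from the previous stage.
                dist.broadcast(x, prev_dst_rank,
                               group=mpu.get_pipeline_parallel_succ_group())
            # Share them with the whole model-parallel group.
            dist.broadcast(x, src_rank, group=mpu.get_model_parallel_group())
        return x

    @staticmethod
    def backward(ctx, grad_x):
        rank = torch.distributed.get_rank()
        prev_dst_rank = mpu.get_model_parallel_prev_dst_rank()
        src_rank = mpu.get_model_parallel_src_rank()
        if prev_dst_rank is not None and rank == src_rank:
            # Return gradients to the previous pipeline stage.
            dist.broadcast(grad_x, src_rank,
                           group=mpu.get_pipeline_parallel_succ_group())
        return grad_x

# Usage inside a stage's forward pass (illustrative):
#     x = PipelineBroadcast.apply(x)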
def get_train_val_test_data(args):
    """Load the data on rank zero and broadcast number of tokens to all GPUs."""
    (train_data, val_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        if args.use_npy_data_loader:
            (train_data, val_data, test_data), num_tokens, \
                eod_token = make_gpt2_dataloaders(args)
        else:
            data_config = configure_data()
            data_config.set_defaults(data_set_type='GPT2', transpose=False)
            (train_data, val_data,
             test_data), tokenizer = data_config.apply(args)
            num_tokens = tokenizer.num_tokens
            eod_token = tokenizer.get_command('eos').Id
            assert eod_token == tokenizer.get_command('pad').Id
        before = num_tokens
        after = before
        multiple = args.make_vocab_size_divisible_by * \
            mpu.get_model_parallel_world_size()
        while (after % multiple) != 0:
            after += 1
        print_rank_0('> padded vocab (size: {}) with {} dummy '
                     'tokens (new size: {})'.format(before, after - before,
                                                    after))
        print_rank_0('> found end-of-document token: {}'.format(eod_token))
        token_counts = torch.cuda.LongTensor([
            after, eod_token,
            int(args.do_train),
            int(args.do_valid),
            int(args.do_test)
        ])
    else:
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    num_tokens = token_counts[0].item()
    eod_token = token_counts[1].item()
    args.do_train = token_counts[2].item()
    args.do_valid = token_counts[3].item()
    args.do_test = token_counts[4].item()

    return train_data, val_data, test_data, num_tokens, eod_token
def mix_forward_step(batch_and_dataloader, model, args, times, mems):
    use_blocklm = 0
    if args.block_lm_ratio > 0.0:
        if mpu.get_model_parallel_rank() == 0:
            # Draw the coin on rank 0 only, then broadcast, so that every
            # rank of the model-parallel group takes the same branch.
            if random.random() > 1 / (1 + args.block_lm_ratio):
                use_blocklm = 1
        use_blocklm = torch.cuda.LongTensor([use_blocklm])
        torch.distributed.broadcast(use_blocklm,
                                    mpu.get_model_parallel_src_rank(),
                                    group=mpu.get_model_parallel_group())
        use_blocklm = use_blocklm.item()
    if use_blocklm:
        return lm_forward_step((batch_and_dataloader[1], None), model, args,
                               times, mems)
    else:
        return finetune_forward_step(batch_and_dataloader[0], model, args,
                                     times, mems)
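# The branch decision above must be identical on all ranks of the group,
# which is why the random draw happens on rank 0 and is broadcast. The
# same pattern in isolation, as a hypothetical helper (`synced_bernoulli`
# is not part of this codebase):
def synced_bernoulli(prob):
    """Sample a 0/1 decision on rank 0 and share it with the group."""
    flag = 0
    if mpu.get_model_parallel_rank() == 0 and random.random() < prob:
        flag = 1
    flag = torch.cuda.LongTensor([flag])
    torch.distributed.broadcast(flag,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    return bool(flag.item())

# With ratio r, mix_forward_step picks the BlockLM branch with probability
# r / (1 + r); e.g. synced_bernoulli(r / (1 + r)) reproduces that choice.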
def get_train_val_test_data(args):
    """Load the data on rank zero and broadcast number of tokens to all GPUs."""
    (train_data, val_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        data_config = configure_data()
        ds_type = 'BERT'
        data_config.set_defaults(data_set_type=ds_type, transpose=False)
        (train_data, val_data, test_data), tokenizer = data_config.apply(args)
        before = tokenizer.num_tokens
        after = before
        multiple = args.make_vocab_size_divisible_by * \
            mpu.get_model_parallel_world_size()
        while (after % multiple) != 0:
            after += 1
        print_rank_0('> padded vocab (size: {}) with {} dummy '
                     'tokens (new size: {})'.format(before, after - before,
                                                    after))
        # Need to broadcast num_tokens and num_type_tokens.
        token_counts = torch.cuda.LongTensor([
            after, tokenizer.num_type_tokens,
            int(args.do_train),
            int(args.do_valid),
            int(args.do_test)
        ])
    else:
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    num_tokens = token_counts[0].item()
    num_type_tokens = token_counts[1].item()
    args.do_train = token_counts[2].item()
    args.do_valid = token_counts[3].item()
    args.do_test = token_counts[4].item()

    return train_data, val_data, test_data, num_tokens, num_type_tokens
def test_get_model_parallel_src_rank(model_parallel_size_):

    if torch.distributed.get_rank() == 0:
        print('> testing get_model_parallel_src_rank with size {} ...'.format(
            model_parallel_size_))
    model_parallel_size = min(model_parallel_size_,
                              torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(model_parallel_size)
    assert mpu.model_parallel_is_initialized()

    # Checks
    src_rank = torch.distributed.get_rank() - mpu.get_model_parallel_rank()
    assert mpu.get_model_parallel_src_rank() == src_rank

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
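# test_get_model_parallel_src_rank assumes torch.distributed has already
# been initialized. A hedged sketch of a launcher that runs it under
# torchrun, doubling the model-parallel size each pass (LOCAL_RANK is the
# standard torchrun environment variable; the rest is hypothetical glue):
if __name__ == '__main__':
    import os
    torch.distributed.init_process_group(backend='nccl')
    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
    model_parallel_size = 1
    while model_parallel_size <= torch.distributed.get_world_size():
        test_get_model_parallel_src_rank(model_parallel_size)
        model_parallel_size *= 2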
def get_eval_data(args):
    val_dataloader = None
    if mpu.get_model_parallel_rank() == 0:
        eval_batch_size = args.eval_batch_size
        eval_batch_size = args.batch_size if eval_batch_size is None \
            else eval_batch_size
        seq_len = args.seq_length
        valid_data = args.valid_data
        valid_data = valid_data[0] if isinstance(valid_data, list) \
            else valid_data
        tokenizer = get_tokenizer(args)

        if not args.cloze_eval:
            with open(valid_data, "rb") as reader:
                entire_data = reader.read().decode('utf-8')
            num_original_tokens = len(entire_data.strip().split(" "))
            entire_data = get_detokenizer(valid_data)(entire_data)
            tokenized_data = tokenizer.EncodeAsIds(entire_data).tokenization
            num_tokenized_tokens = len(tokenized_data)
            string = 'Original Tokens: %d, Detokenized tokens: %d' % (
                num_original_tokens, num_tokenized_tokens)
            print_rank_0(string)
            eod_token = tokenizer.get_command('pad').Id
            val_dataset = LM_Eval_Dataset(tokenized_data, seq_len, eod_token,
                                          args.overlapping_eval)
        else:
            val_dataset = Lambada_Eval_Dataset(valid_data, tokenizer, seq_len)
            num_tokenized_tokens = 0
            num_original_tokens = 0
        val_dataloader = torch.utils.data.DataLoader(
            val_dataset, batch_size=eval_batch_size, drop_last=False)

        before = tokenizer.num_tokens
        after = before
        while after % mpu.get_model_parallel_world_size() != 0:
            after += 1
        print_rank_0('> padded vocab (size: {}) with {} dummy tokens '
                     '(new size: {})'.format(before, after - before, after))
        eod_token = tokenizer.get_command('pad').Id
        num_examples = len(val_dataset)
        token_counts = torch.cuda.LongTensor([
            after, eod_token, num_examples, num_original_tokens,
            num_tokenized_tokens
        ])
    else:
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])

    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.vocab_size = token_counts[0].item()
    args.eod_token = token_counts[1].item()
    args.num_examples = token_counts[2].item()
    args.num_original_tokens = token_counts[3].item()
    args.num_tokenized_tokens = token_counts[4].item()
    print('global rank: {} | vocab size: {} | eod token: {} | '
          'num_examples: {} | num_original_tokens: {} | '
          'num_tokenized_tokens: {}'.format(
              torch.distributed.get_rank(), args.vocab_size, args.eod_token,
              args.num_examples, args.num_original_tokens,
              args.num_tokenized_tokens))
    return val_dataloader
def generate_samples(model, tokenizer, args, device):
    context_count = 0
    model.eval()
    fout = open(args.output_path, 'w', encoding='utf-8')
    with torch.no_grad():
        data_iterator = binglr_iterator_dataset([args.valid_data],
                                                run_once=True,
                                                max_seq_len=args.seq_length,
                                                mask_lm_prob=0.15,
                                                max_preds_per_seq=20,
                                                tokenizer=tokenizer,
                                                train=False)
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                data = next(data_iterator, None)
                if data is None:
                    terminate_runs = 1
                else:
                    # Unpack.
                    tokens = data['text']
                    types = data['types']
                    loss_mask = data['mask']
                    lm_labels = data['mask_labels']
                    padding_mask = data['pad_mask']
                    clickscores = data['clickscores']
                    hrsscores = data['hrsscores']
                    sample_id = data['sample_id']
            else:
                # Placeholders; the real values are broadcast below.
                seq_length = args.seq_length
                tokens = np.array([0] * seq_length)
                types = np.array([0] * seq_length)
                loss_mask = np.array([0] * seq_length)
                lm_labels = np.array([0] * seq_length)
                padding_mask = np.array([0] * seq_length)
                clickscores = np.array([0.0])
                hrsscores = np.array([0.0])
                sample_id = np.array([0])

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()
            if terminate_runs == 1:
                return

            tokens = torch.cuda.LongTensor(tokens).view(1, -1)
            types = torch.cuda.LongTensor(types).view(1, -1)
            loss_mask = torch.cuda.LongTensor(loss_mask).view(1, -1)
            lm_labels = torch.cuda.LongTensor(lm_labels).view(1, -1)
            padding_mask = torch.cuda.LongTensor(padding_mask).view(1, -1)
            clickscores = torch.cuda.FloatTensor(clickscores)
            hrsscores = torch.cuda.FloatTensor(hrsscores)

            # Get the masks and position ids.
            batch_size, seq_length = tokens.size()
            attention_mask = (
                torch.ones_like(padding_mask, device=padding_mask.device) -
                padding_mask).view(batch_size, 1, seq_length, 1) * (
                    torch.ones_like(padding_mask, device=padding_mask.device)
                    - padding_mask).view(batch_size, 1, 1, seq_length)
            position_ids = torch.arange(seq_length,
                                        dtype=torch.long,
                                        device=tokens.device)
            position_ids = position_ids.unsqueeze(0).expand_as(tokens)

            torch.distributed.broadcast(tokens,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            torch.distributed.broadcast(types,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            torch.distributed.broadcast(clickscores,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            torch.distributed.broadcast(hrsscores,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            torch.distributed.broadcast(attention_mask,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            torch.distributed.broadcast(position_ids,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())

            _, hrs_scores, _, _ = model(tokens.contiguous(),
                                        position_ids.contiguous(),
                                        attention_mask.contiguous(),
                                        types.contiguous())

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                fout.write('\t'.join([
                    str(sample_id),
                    str(hrs_scores.detach().clone().cpu().numpy())
                ]) + '\n')
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
def generate_samples(model, tokenizer, args, device):
    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                if args.input_text:
                    raw_text = open(args.input_text).read().strip()
                else:
                    raw_text = input("\nContext prompt (stop to exit) >>> ")
                    while not raw_text:
                        print('Prompt should not be empty!')
                        raw_text = input(
                            "\nContext prompt (stop to exit) >>> ")
                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.encode(raw_text)
                    context_length = len(context_tokens)
                    if context_length >= args.seq_length // 2:
                        print("\nContext length", context_length,
                              "\nPlease give smaller context "
                              "(half of the sequence length)!")
                        continue
            else:
                # Placeholder context on the other ranks ("空文本" means
                # "empty text"); the real tokens are broadcast below.
                context_tokens = tokenizer.encode("空文本")
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()
            if terminate_runs == 1:
                return

            pad_id = tokenizer.encoder['<pad>']
            args.eod_token = tokenizer.encoder['<eod>']
            if context_length < args.seq_length:
                context_tokens.extend([pad_id] *
                                      (args.seq_length - context_length))

            context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
            context_length_tensor = torch.cuda.LongTensor([context_length])

            torch.distributed.broadcast(context_length_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            torch.distributed.broadcast(context_tokens_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())

            context_length = context_length_tensor[0].item()
            tokens, attention_mask, position_ids = get_batch(
                context_tokens_tensor, device, args)

            start_time = time.time()
            counter = 0
            org_context_length = context_length
            past_key_values = None
            while counter < (org_context_length + args.out_seq_length):
                if counter == 0:
                    # First step: run the full context through the model.
                    logits, past_key_values = model(
                        tokens[:, :context_length],
                        position_ids[:, :context_length],
                        attention_mask[:, :, :context_length,
                                       :context_length],
                        past_key_values=past_key_values,
                        use_cache=True)
                    logits = logits[:, context_length - 1, :]
                else:
                    # Subsequent steps: feed only the newest token and
                    # reuse the cached key/values.
                    logits, past_key_values = model(
                        tokens[:, context_length - 1:context_length],
                        position_ids[:, context_length - 1:context_length],
                        attention_mask[:, :, context_length - 1,
                                       :context_length],
                        past_key_values=past_key_values,
                        use_cache=True)
                    logits = logits[:, 0, :]
                past_key_values = [x.half() for x in past_key_values]
                logits = top_k_logits(logits,
                                      top_k=args.top_k,
                                      top_p=args.top_p)
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1)
                tokens[0, context_length] = prev[0]
                torch.distributed.broadcast(
                    tokens, mpu.get_model_parallel_src_rank(),
                    group=mpu.get_model_parallel_group())
                context_length += 1
                counter += 1

                output_tokens_list = tokens.view(-1).contiguous()
                decode_tokens = tokenizer.decode(output_tokens_list.tolist())
                token_end = decode_tokens.find("<eod>")

                if mpu.get_model_parallel_rank() == 0 and (
                        counter % 16 == 0 or token_end != -1):
                    os.system('clear')
                    print("\nTaken time {:.2f}\n".format(time.time() -
                                                         start_time),
                          flush=True)
                    print("\nContext:", raw_text, flush=True)
                    trim_decode_tokens = decode_tokens[
                        len(raw_text):decode_tokens.find("<eod>")]
                    print("\nCPM:", trim_decode_tokens, flush=True)
                if token_end != -1:
                    break

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nTaken time {:.2f}\n".format(time.time() -
                                                     start_time),
                      flush=True)
                print("\nContext:", raw_text, flush=True)
                output_tokens_list = tokens.view(-1).contiguous()
                decode_tokens = tokenizer.decode(output_tokens_list.tolist())
                trim_decode_tokens = decode_tokens[
                    len(raw_text):decode_tokens.find("<eod>")]
                print("\nCPM:", trim_decode_tokens, flush=True)

            raw_text = None
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1
            if args.input_text:
                break
def generate_samples(model, tokenizer, args, device):
    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                raw_text = 'what is freedom of speech?'
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = 'what is freedom of speech?'
                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.EncodeAsIds(
                        raw_text).tokenization
                    context_length = len(context_tokens)
                    if context_length >= args.seq_length // 2:
                        print("\nContext length", context_length,
                              "\nPlease give smaller context "
                              "(half of the sequence length)!")
                        continue
            else:
                context_tokens = tokenizer.EncodeAsIds(
                    "EMPTY TEXT").tokenization
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()
            if terminate_runs == 1:
                return

            pad_id = tokenizer.get_command('pad').Id
            if context_length < args.seq_length:
                context_tokens.extend([pad_id] *
                                      (args.seq_length - context_length))

            context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
            context_length_tensor = torch.cuda.LongTensor([context_length])

            torch.distributed.broadcast(context_length_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            torch.distributed.broadcast(context_tokens_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())

            context_length = context_length_tensor[0].item()
            tokens, attention_mask, position_ids = get_batch(
                context_tokens_tensor, device, args)

            start_time = time.time()
            counter = 0
            org_context_length = context_length
            while counter < (org_context_length + args.out_seq_length):
                logits = model(tokens, position_ids, attention_mask)
                logits = logits[:, context_length - 1, :] / args.temperature
                logits = top_k_logits(logits,
                                      top_k=args.top_k,
                                      top_p=args.top_p)
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1)
                tokens[0, context_length] = prev[0]
                context_length += 1
                counter += 1

                output_tokens_list = tokens.view(-1).contiguous()
                decode_tokens = tokenizer.DecodeIds(
                    output_tokens_list.tolist())
                token_end = decode_tokens.find("<|endoftext|>")

                if mpu.get_model_parallel_rank() == 0 and (
                        counter % 16 == 0 or token_end != -1):
                    os.system('clear')
                    print("\nTaken time {:.2f}\n".format(time.time() -
                                                         start_time),
                          flush=True)
                    print("\nContext:", raw_text, flush=True)
                    trim_decode_tokens = decode_tokens[
                        len(raw_text):decode_tokens.find("<|endoftext|>")]
                    print("\nGPT2:", trim_decode_tokens, flush=True)
                if token_end != -1:
                    break

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nTaken time {:.2f}\n".format(time.time() -
                                                     start_time),
                      flush=True)
                print("\nContext:", raw_text, flush=True)
                output_tokens_list = tokens.view(-1).contiguous()
                decode_tokens = tokenizer.DecodeIds(
                    output_tokens_list.tolist())
                trim_decode_tokens = decode_tokens[
                    len(raw_text):decode_tokens.find("<|endoftext|>")]
                print("\nGPT2:", trim_decode_tokens, flush=True)

            raw_text = None
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider, args):
    """Build train, validation, and test data iterators."""

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')
    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        # Rank, size, and global batch size.
        data_parallel_size = mpu.get_data_parallel_world_size()
        global_batch_size = args.batch_size * data_parallel_size

        # Number of train/valid/test samples.
        train_iters = args.train_iters
        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [
            train_iters * global_batch_size,
            eval_iters * global_batch_size,
            test_iters * global_batch_size
        ]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(
            train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(
            train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(
            train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders.
        train_dataloader = make_data_loader(train_ds)
        valid_dataloader = make_data_loader(valid_ds)
        test_dataloader = make_data_loader(test_ds)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Pack the flags for broadcast.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the flags.
    torch.distributed.broadcast(flags,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Shift the start iterations.
    if train_dataloader is not None:
        train_dataloader.batch_sampler.start_iter = args.iteration % \
            len(train_dataloader)
        print_rank_0('setting training data start iteration to {}'.format(
            train_dataloader.batch_sampler.start_iter))
    if valid_dataloader is not None:
        start_iter_val = (args.iteration // args.eval_interval) * \
            args.eval_iters
        valid_dataloader.batch_sampler.start_iter = start_iter_val % \
            len(valid_dataloader)
        print_rank_0('setting validation data start iteration to {}'.format(
            valid_dataloader.batch_sampler.start_iter))

    # Build iterators.
    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader)
    else:
        train_data_iterator = None
    if valid_dataloader is not None:
        valid_data_iterator = iter(valid_dataloader)
    else:
        valid_data_iterator = None
    if test_dataloader is not None:
        test_data_iterator = iter(test_dataloader)
    else:
        test_data_iterator = None

    return train_data_iterator, valid_data_iterator, test_data_iterator
def __init__(self, config, batch_slices, seq_slices, distributed_init_method,
             world_size, data_parallel_size, model_parallel_size,
             pipeline_parallel_size, rank, local_rank, mixed_precision=False,
             use_mpi=False, init_process_group=False,
             checkpoint_gradients=False):
    self.config = config
    self.batch_slices = batch_slices
    self.seq_slices = seq_slices
    torch.cuda.set_device(local_rank)
    if init_process_group:
        dist.init_process_group(
            backend='nccl',
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank,
        )
    # Warm up communication before setting up the parallel groups.
    dist.all_reduce(torch.zeros(1).cuda())
    mpu.initialize_model_parallel(model_parallel_size,
                                  pipeline_parallel_size)
    set_random_seed(0)
    mpu.model_parallel_cuda_manual_seed(0)
    self.rank = rank
    self.local_rank = local_rank
    self.world_size = world_size
    self.data_parallel_size = data_parallel_size
    self.model_parallel_size = model_parallel_size
    self.pipeline_parallel_size = pipeline_parallel_size
    self.pipeline_parallel_group_rank = mpu.get_pipeline_parallel_group_rank()
    self.data_parallel_group = mpu.get_data_parallel_group()
    self.model_parallel_group = mpu.get_model_parallel_group()
    self.pipeline_parallel_pred_group = mpu.get_pipeline_parallel_pred_group()
    self.pipeline_parallel_succ_group = mpu.get_pipeline_parallel_succ_group()
    self.model_parallel_src_rank = mpu.get_model_parallel_src_rank()
    self.model_parallel_dst_rank = mpu.get_model_parallel_dst_rank()
    self.model_parallel_next_src_rank = (
        self.model_parallel_src_rank + self.model_parallel_size
        if self.pipeline_parallel_group_rank < self.pipeline_parallel_size - 1
        else None)
    self.model_parallel_prev_dst_rank = (
        self.model_parallel_dst_rank - self.model_parallel_size
        if self.pipeline_parallel_group_rank > 0 else None)

    # Distribute the layers over the pipeline stages as evenly as possible.
    self.n_layers = (config.n_layers // pipeline_parallel_size +
                     int(rank < config.n_layers % pipeline_parallel_size))
    self.mixed_precision = mixed_precision
    self.checkpoint_gradients = checkpoint_gradients

    self.layers = []
    for _ in range(self.n_layers):
        l = ModelParallelTransformerLayer(
            self.config.embedding_dim,
            self.config.ffn_embedding_dim,
            self.config.num_attention_heads,
            device="cuda",
            checkpoint_gradients=self.checkpoint_gradients)
        self.layers.append(l.half() if self.mixed_precision else l)

    self.all_parameters = []
    for layer in self.layers:
        self.all_parameters.extend(layer.parameters())
    self.n_params = len(self.all_parameters)

    if self.mixed_precision:
        # Keep fp32 master copies of the fp16 parameters.
        self.master_parameters = [
            p.clone().detach().float() for p in self.all_parameters
        ]
        for p in self.master_parameters:
            p.requires_grad_()
        self.optimizer = optimizers.FusedAdam(self.master_parameters,
                                              lr=1e-10)
    else:
        self.optimizer = torch.optim.Adam(self.all_parameters, lr=1e-10)
def finetune(args,
             train_valid_datasets_provider,
             model_kwargs,
             forward_step=finetune_forward_step,
             end_of_epoch_callback_provider=None):
    """Main finetune function used across all tasks."""
    global tokenizer
    timers = Timers()
    tokenizer = prepare_tokenizer(args)
    pretrain_glm.tokenizer = tokenizer
    if args.save:
        args.save = os.path.join(args.save, args.experiment_name)

    # Train and validation data loaders.
    timers('train/valid/test dataset/dataloader').start()
    train_dataloader, valid_dataloader = None, None
    train_block_dataloader, valid_block_dataloader = None, None
    if train_valid_datasets_provider is not None and args.epochs > 0:
        if mpu.get_model_parallel_rank() == 0:
            train_dataset, valid_dataset = train_valid_datasets_provider(
                args, tokenizer)
            train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
                train_dataset, valid_dataset, args)
            if args.no_validation:
                valid_dataloader = None
            train_iters = torch.cuda.LongTensor([len(train_dataloader)])
        else:
            train_iters = torch.cuda.LongTensor([0])
        torch.distributed.broadcast(train_iters,
                                    mpu.get_model_parallel_src_rank(),
                                    group=mpu.get_model_parallel_group())
        if mpu.get_model_parallel_rank() != 0:
            args.train_iters_per_epoch = train_iters[0].item()
            args.train_iters = args.epochs * args.train_iters_per_epoch
            train_dataloader = FakeDataloader(args.train_iters_per_epoch)
            if args.no_validation:
                valid_dataloader = None
            else:
                valid_dataloader = FakeDataloader(None)
        if args.block_lm_ratio > 0.0:
            if mpu.get_model_parallel_rank() == 0:
                train_block_dataset, valid_block_dataset = \
                    train_valid_datasets_provider(args, tokenizer,
                                                  pattern_text=True)
                train_block_dataloader = make_data_loader(
                    train_block_dataset,
                    tokenizer,
                    args.batch_size * mpu.get_data_parallel_world_size(),
                    args.train_iters,
                    args,
                    shuffle=True,
                    block_collate=True)
                valid_block_dataloader = make_data_loader(
                    valid_block_dataset,
                    tokenizer,
                    args.batch_size * mpu.get_data_parallel_world_size(),
                    (args.train_iters // args.eval_interval + 1) *
                    args.eval_iters,
                    args,
                    shuffle=True,
                    block_collate=True)
            else:
                train_block_dataloader = FakeDataloader(args.train_iters)
                valid_block_dataloader = FakeDataloader(None)
            train_block_dataloader, valid_block_dataloader = iter(
                train_block_dataloader), iter(valid_block_dataloader)
    timers('train/valid/test dataset/dataloader').stop()

    # Build callback function.
    timers('callback function').start()
    end_of_epoch_callback, end_of_train_callback = None, None
    if end_of_epoch_callback_provider is not None:
        if train_valid_datasets_provider is not None and args.epochs > 0 \
                and not args.no_validation:
            end_of_epoch_callback = end_of_epoch_callback_provider(
                args, tokenizer, is_test=False)
        end_of_train_callback = end_of_epoch_callback_provider(args,
                                                               tokenizer,
                                                               is_test=True)
    timers('callback function').stop()

    # Build model, optimizer and learning rate scheduler.
    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(
        args, **model_kwargs)
    timers('model and optimizer').stop()

    # If pretrained checkpoint is provided and we have not trained for
    # any iteration (i.e., iteration is zero), then load the pretrained
    # checkpoint.
    timers('pretrained checkpoint').start()
    if args.load_pretrained is not None and not args.pretrained_bert:
        task_tokens = None
        if args.continuous_prompt and args.prompt_init:
            if mpu.get_model_parallel_rank() == 0:
                dataset = train_dataloader.dataset
                processor, pvp = dataset.processor, dataset.pvp
                task_tokens = []
                for label in processor.get_labels():
                    verbalizer = pvp.verbalize(label)[0]
                    verbalizer_ids = tokenizer.EncodeAsIds(
                        verbalizer).tokenization
                    task_tokens += verbalizer_ids
                print_rank_0("Task tokens: " +
                             tokenizer.DecodeIds(task_tokens))
                num_task_tokens = len(task_tokens)
            else:
                num_task_tokens, task_tokens = 0, []
            num_task_tokens = torch.cuda.LongTensor([num_task_tokens])
            torch.distributed.broadcast(num_task_tokens,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            num_task_tokens = num_task_tokens.item()
            if num_task_tokens > 0:
                if mpu.get_model_parallel_rank() == 0:
                    task_tokens = torch.cuda.LongTensor(task_tokens)
                else:
                    task_tokens = torch.empty(
                        num_task_tokens,
                        device=torch.cuda.current_device(),
                        dtype=torch.long)
                torch.distributed.broadcast(
                    task_tokens, mpu.get_model_parallel_src_rank(),
                    group=mpu.get_model_parallel_group())
                task_tokens = task_tokens.tolist()
        with FileLock(os.path.join(pathlib.Path.home(), "checkpoint_lock"),
                      timeout=-1):
            load_pretrained(model, args.load_pretrained, args,
                            task_tokens=task_tokens)
        # This is critical when only model is loaded. We should make sure
        # master parameters are also updated.
        if args.fp16 and optimizer is not None:
            if args.deepspeed:
                optimizer.refresh_fp32_params()
            else:
                optimizer._model_params_to_master_params()
    if args.load is not None:
        with FileLock(os.path.join(pathlib.Path.home(), "checkpoint_lock"),
                      timeout=-1):
            load_checkpoint(model, optimizer, lr_scheduler, args,
                            no_deepspeed=args.no_deepspeed_load)
        # This is critical when only model is loaded. We should make sure
        # master parameters are also updated.
        if args.fp16 and optimizer is not None:
            if args.deepspeed:
                optimizer.refresh_fp32_params()
            else:
                optimizer._model_params_to_master_params()
    torch.distributed.barrier()
    timers('pretrained checkpoint').stop()

    args.iteration = 0
    summary_writer = None
    if torch.distributed.get_rank() == 0:
        args.log_dir = get_log_dir(base=args.summary_dir,
                                   name=args.experiment_name)
        if os.path.exists(os.path.join(args.log_dir, "test_results.json")
                          ) and args.load is None and not args.overwrite:
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".
                format(args.log_dir))
        summary_writer = get_sample_writer(log_dir=args.log_dir,
                                           iteration=args.iteration)
        print_and_save_args(args, verbose=True, log_dir=args.log_dir)

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log([
        'train/valid/test dataset/dataloader', 'callback function',
        'model and optimizer', 'pretrained checkpoint'
    ])
    print_rank_0('training ...')

    # Finetune the model.
    score_dict = None
    if train_dataloader is not None and args.epochs > 0:
        if args.block_lm_ratio > 0.0:
            forward_step = mix_forward_step
        best_iteration = _train(model,
                                optimizer,
                                lr_scheduler,
                                forward_step,
                                (train_dataloader, train_block_dataloader),
                                (valid_dataloader, valid_block_dataloader),
                                end_of_epoch_callback,
                                args,
                                timers,
                                summary_writer=summary_writer)
        if end_of_train_callback is not None and best_iteration is not None:
            with FileLock(os.path.join(pathlib.Path.home(),
                                       "checkpoint_lock"), timeout=-1):
                args.load = os.path.join(args.save, "best")
                load_checkpoint(model, optimizer, lr_scheduler, args,
                                no_load_optim=True, no_deepspeed=True)
                args.load = None
        torch.distributed.barrier()
        if end_of_train_callback is not None:
            score_dict = end_of_train_callback(model,
                                               epoch=-1,
                                               output_predictions=True)
    # Or just evaluate.
    else:
        if end_of_train_callback is not None:
            print_rank_0('evaluation only mode, setting epoch to -1')
            score_dict = end_of_train_callback(model,
                                               epoch=-1,
                                               output_predictions=True)
    if score_dict is not None and torch.distributed.get_rank() == 0:
        score_dict.update({"type": "test"})
        with open(os.path.join(args.log_dir, "test_results.json"),
                  "w") as output:
            output.write(json.dumps(score_dict) + "\n")

    print_rank_0('done :-)')
def generate_samples_interactive(model, tokenizer, args):
    print_frequency = 24
    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")
                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.EncodeAsIds(
                        raw_text).tokenization
                    context_length = len(context_tokens)
            else:
                context_tokens = tokenizer.EncodeAsIds(
                    "EMPTY TEXT").tokenization
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()
            if terminate_runs == 1:
                return

            start_time = time.time()
            token_stream = get_token_stream(model, [context_tokens],
                                            tokenizer, args)
            for counter, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                if mpu.get_model_parallel_rank() == 0 \
                        and counter % print_frequency == 0:
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)
                    trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.get_model_parallel_rank() == 0 and decode_tokens:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)
                trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            raw_text = None
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1
            if mpu.get_model_parallel_rank() == 0:
                input("\nPress any key to continue >>>")
def generate_samples_input_from_file(model, tokenizer, args):
    if args.sample_input_file == "":
        if mpu.get_model_parallel_rank() == 0:
            print("args.sample_input_file cannot be empty!\n")
        return

    if mpu.get_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file == "":
            print("Argument sample-output-file is empty, defaulting to "
                  "args.sample_input_file + '.out'")
            args.sample_output_file = args.sample_input_file + ".out"
        fname_out = open(args.sample_output_file, "w+")

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                if input_pos == input_count:
                    raw_text = "stop"
                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.EncodeAsIds(
                        raw_text).tokenization
                    context_length = len(context_tokens)
            else:
                context_tokens = tokenizer.EncodeAsIds(
                    "EMPTY TEXT").tokenization
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()
            if terminate_runs == 1:
                return

            start_time = time.time()
            token_stream = get_token_stream(model, [context_tokens],
                                            tokenizer, args)
            for counter, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

            if mpu.get_model_parallel_rank() == 0 and decode_tokens:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)
                trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)
                fname_out.write("\nContext:")
                fname_out.write(raw_text)
                fname_out.write("\n\nMegatron-LM:")
                fname_out.write(trim_decode_tokens)
                fname_out.write("\n")

            raw_text = None
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1
def generate_samples(model, tokenizer, args, device):
    context_count = 0
    model.eval()
    output_path = "./samples"
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    output_path = os.path.join(
        output_path, f"sample-{datetime.now().strftime('%m-%d-%H-%M')}.txt")
    with torch.no_grad(), open(output_path, "w") as output:
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")
                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    output.write(raw_text)
                    context_tokens = tokenizer.EncodeAsIds(
                        raw_text).tokenization
                    context_length = len(context_tokens)
                    if context_length >= args.seq_length:
                        print("\nContext length", context_length,
                              "\nPlease give context shorter than the "
                              "sequence length!")
                        continue
            else:
                context_tokens = tokenizer.EncodeAsIds(
                    "EMPTY TEXT").tokenization
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()
            if terminate_runs == 1:
                return

            context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
            context_length_tensor = torch.cuda.LongTensor([context_length])
            torch.distributed.broadcast(context_length_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            torch.distributed.broadcast(context_tokens_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())

            context_length = context_length_tensor[0].item()
            tokens, attention_mask, position_ids = get_batch(
                context_tokens_tensor, device, args)

            start_time = time.time()
            counter, mems = 0, []
            org_context_length = context_length
            while counter < (args.out_seq_length - org_context_length):
                if counter == 0:
                    # First step: run the full context and build the memory.
                    logits, *mems = model(tokens, position_ids,
                                          attention_mask, *mems)
                else:
                    # Later steps: feed only the newest token and attend
                    # over the transformer-XL style memory.
                    index = org_context_length + counter
                    logits, *mems = model(
                        tokens[:, index - 1:index],
                        tokens.new_ones((1, 1)) * (index - 1),
                        tokens.new_ones(1, 1, 1, args.mem_length + 1,
                                        device=tokens.device,
                                        dtype=torch.float), *mems)
                logits = logits[:, -1]
                logits /= args.temperature
                logits = top_k_logits(logits,
                                      top_k=args.top_k,
                                      top_p=args.top_p)
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1)[0]
                tokens = torch.cat((tokens, prev.view(1, 1)), dim=1)
                context_length += 1
                counter += 1

                output_tokens_list = tokens.view(-1).contiguous()
                decode_tokens = tokenizer.DecodeIds(
                    output_tokens_list.tolist())
                is_end = prev == args.eod_token

                if mpu.get_model_parallel_rank() == 0 and (
                        counter % 128 == 0 or is_end):
                    os.system('clear')
                    print("\nTaken time {:.2f}\n".format(time.time() -
                                                         start_time),
                          flush=True)
                    print("\nContext:", raw_text, flush=True)
                    trim_decode_tokens = decode_tokens[
                        len(raw_text):decode_tokens.find("<|endoftext|>")]
                    print("\nGPT2:", trim_decode_tokens, flush=True)
                if is_end:
                    break

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nTaken time {:.2f}\n".format(time.time() -
                                                     start_time),
                      flush=True)
                print("\nContext:", raw_text, flush=True)
                output_tokens_list = tokens.view(-1).contiguous()
                decode_tokens = tokenizer.DecodeIds(
                    output_tokens_list.tolist())
                trim_decode_tokens = decode_tokens[
                    len(raw_text):decode_tokens.find("<|endoftext|>")]
                print("\nGPT2:", trim_decode_tokens, flush=True)
                output.write(trim_decode_tokens + "\n")

            raw_text = None
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1
def generate_samples_input_from_file(model, tokenizer, args):
    if mpu.get_model_parallel_rank() == 0:
        all_raw_text = args.text
        input_count = len(all_raw_text)
        input_pos = 0

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                if input_pos == input_count:
                    raw_text = "stop"
                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.EncodeAsIds(
                        raw_text).tokenization
                    context_length = len(context_tokens)
                    if context_length >= args.seq_length // 2:
                        print("\nContext length", context_length,
                              "\nPlease give smaller context "
                              "(half of the sequence length)!")
                        continue
            else:
                context_tokens = tokenizer.EncodeAsIds(
                    "EMPTY TEXT").tokenization
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()
            if terminate_runs == 1:
                return

            start_time = time.time()
            token_stream = get_token_stream(model, [context_tokens],
                                            tokenizer, args)
            for counter, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                if mpu.get_model_parallel_rank() == 0:
                    os.system('clear')
                    print("\nTaken time {:.2f}\n".format(time.time() -
                                                         start_time),
                          flush=True)
                    trim_decode_tokens = tokenizer.DecodeIds(
                        decode_tokens)[len(raw_text):]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)
                    yield trim_decode_tokens

            raw_text = None
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1
def __call__(self, text=None, input_ids=None, labels=None, **kwargs):
    if input_ids is None:
        if text is None:
            text = ""
        input_ids = torch.cuda.LongTensor(
            [self.tokenizer(text)['input_ids']])
    if isinstance(input_ids, list):
        input_ids = torch.cuda.LongTensor(input_ids)
    if isinstance(labels, list):
        labels = torch.cuda.LongTensor(labels)

    res = []
    if labels is not None:
        lbls = labels
    else:
        lbls = [None] * len(input_ids)
    loss = None
    original_context_length = 0
    for tokens, lbl in zip(input_ids, lbls):
        context_tokens = tokens.tolist()
        context_length = len(context_tokens)
        original_context_length = len(context_tokens)
        if context_length < self.seq_len:
            context_tokens.extend([self.pad_token_id] *
                                  (self.seq_len - context_length))
            if labels is not None:
                lbl = lbl.tolist()
                lbl.extend([self.pad_token_id] *
                           (self.seq_len - context_length))
                lbl = torch.cuda.LongTensor(lbl)
        context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
        context_length_tensor = torch.cuda.LongTensor([context_length])
        torch.distributed.broadcast(context_length_tensor,
                                    mpu.get_model_parallel_src_rank(),
                                    group=mpu.get_model_parallel_group())
        torch.distributed.broadcast(context_tokens_tensor,
                                    mpu.get_model_parallel_src_rank(),
                                    group=mpu.get_model_parallel_group())
        tokens = context_tokens_tensor
        tokens = tokens.view(1, -1).contiguous()
        tokens = tokens.to(torch.cuda.current_device())
        attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
            tokens, self.pad_token_id, False, False)

        lm_logits = self.model(tokens, position_ids, attention_mask)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lbl[..., 1:].contiguous()
            # Flatten the tokens.
            loss_fct = CrossEntropyLoss(ignore_index=self.pad_token_id)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
        res.append((lm_logits, loss))

    logits = torch.cat([x[0] for x in res],
                       dim=0)[:, :original_context_length, :]
    if loss is not None:
        loss = [x[1] for x in res]
    return ModelOutput(logits, loss)
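# A hedged usage sketch for the __call__ wrapper above. The surrounding
# class and its constructor are not shown in this file, so the instance
# name and construction are hypothetical:
#
#     wrapper = MegatronWrapper(...)            # hypothetical constructor
#     out = wrapper(text="Hello world")         # logits only
#     out = wrapper(input_ids=[[1, 2, 3]],      # logits plus per-sequence
#                   labels=[[1, 2, 3]])         # cross-entropy losses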