def get_model(model_provider_func):
    """Build the model."""
    args = get_args()

    # Build model on cpu.
    model = model_provider_func()

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    if args.DDP_impl == 'torch':
        i = torch.cuda.current_device()
        model = torchDDP(model, device_ids=[i], output_device=i,
                         process_group=mpu.get_data_parallel_group())
        return model
    if args.DDP_impl == 'local':
        model = LocalDDP(model)
        return model

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))
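# Hedged usage sketch (not from the original source): get_model expects a
# zero-argument provider that builds the network on CPU; device placement,
# fp16 wrapping, and DDP wrapping all happen inside get_model. The provider
# body below assumes a Megatron-style GPT2Model constructor is in scope.
def _example_build_model():
    def gpt2_model_provider():
        # Build on CPU; get_model moves the module to the current CUDA device.
        return GPT2Model(num_tokentypes=0, parallel_output=True)

    return get_model(gpt2_model_provider)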
def get_checkpoint_name(checkpoints_path, iteration, release=False,
                        mp_rank=None):
    """A unified checkpoint name."""
    if release:
        directory = 'release'
    else:
        directory = 'iter_{:07d}'.format(iteration)
    return os.path.join(checkpoints_path, directory,
                        'mp_rank_{:02d}'.format(
                            mpu.get_model_parallel_rank()
                            if mp_rank is None else mp_rank),
                        'model_optim_rng.pt')
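# Illustrative paths produced by get_checkpoint_name (argument values assumed):
#   get_checkpoint_name('/ckpts', 5000)
#       -> '/ckpts/iter_0005000/mp_rank_00/model_optim_rng.pt'   (on mp rank 0)
#   get_checkpoint_name('/ckpts', 5000, release=True, mp_rank=3)
#       -> '/ckpts/release/mp_rank_03/model_optim_rng.pt'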
def get_total_params(model):
    """Sum parameter counts across all model-parallel ranks."""
    # Count parameters once per data-parallel group (dp rank 0), then
    # all-reduce so that every rank sees the global total.
    if mpu.get_data_parallel_rank() == 0:
        params = sum([p.nelement() for p in model.parameters()])
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(), params), flush=True)
    else:
        params = 0

    total_n_parameters = torch.tensor(
        [params]).cuda(torch.cuda.current_device())
    torch.distributed.all_reduce(total_n_parameters)
    total_n_parameters = total_n_parameters.item()
    return total_n_parameters
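# Hedged call-site sketch (not from the original source): get_total_params
# performs an all_reduce, so it must be called collectively on every rank.
def _example_log_total_params(model):
    total_n_parameters = get_total_params(model)
    print_rank_0(' > total number of parameters: {:,}'.format(
        total_n_parameters))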
def __init__(self, model, forward_step_fn, neox_args, batch_size=None):
    self.cache_hook = base.CacheHook(None)
    self.model = model
    self.neox_args = neox_args
    self.tokenizer = neox_args.tokenizer
    self._device = torch.device(f"cuda:{neox_args.local_rank}")
    self._eot_token_id = neox_args.tokenizer.eod_id
    self._max_length = neox_args.max_position_embeddings // 2
    self._max_gen_toks = 128
    self._vocab_size = neox_args.padded_vocab_size

    # parallelism args:
    self.is_main = neox_args.rank == 0
    self.is_local_main = neox_args.local_rank == 0
    self.is_model_parallel = neox_args.model_parallel_size > 1
    self.is_pipe_parallel = self.model.is_pipe_parallel
    self.is_data_parallel = self.model.is_data_parallel
    self.is_last_stage = (
        True if not self.is_pipe_parallel else model.is_last_stage()
    )  # only the last stage of the pipeline model will receive the logits
    self.dp_world_size = mpu.get_data_parallel_world_size()
    self.dp_rank = mpu.get_data_parallel_rank()
    self.dp_group = mpu.get_data_parallel_group()
    self.is_mp_rank_0 = mpu.get_model_parallel_rank() == 0

    self._batch_size = batch_size or (
        neox_args.batch_size * self.dp_world_size
    )  # default batch size to bs per gpu * dp size

    # some utility functions:
    # we need to patch tokenizer methods, because lm_eval uses them internally:
    self.tokenizer.encode = self.tokenizer.tokenize
    self.tokenizer.decode = self.tokenizer.detokenize

    self._forward_step_fn = partial(
        forward_step_fn, neox_args=neox_args, timers=None, return_logits=True
    )
    self.generate = partial(
        generate_samples_from_prompt,
        neox_args=neox_args,
        model=model,
        maximum_tokens=self._max_gen_toks,
        temperature=0.0,
    )
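# Hedged construction sketch: the enclosing class name (`EvalHarnessAdapter`)
# and the `forward_step` function are assumptions; the constructor signature
# is the one above, and `generate` is the partial bound in __init__.
# adapter = EvalHarnessAdapter(model, forward_step_fn=forward_step,
#                              neox_args=neox_args)
# adapter.generate(text=["The quick brown fox"])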
def generate_samples_interactive(
    neox_args,
    model,
    maximum_tokens: int = 64,
    eos_token_id: int = None,
    recompute: bool = False,
    temperature: float = 0.0,
    top_k: int = 0,
    top_p: float = 0.0,
):
    """
    Generates samples interactively from prompts entered on stdin and prints
    the completions.

    neox_args: NeoXArgs.
    model: a Megatron model.
    maximum_tokens: maximum number of tokens to be generated.
    eos_token_id: end of text token at which completion is terminated, even
        if maximum_tokens count has not been reached.
    recompute: flag indicating whether a cache is used for already forwarded
        tokens (true) or whether all tokens are recomputed at every iteration
        (false).
    temperature (default 0.0): exponential scaling of the output distribution
        ("higher == more risk").
    top_k (default 0): integer between 0 and the model's vocab size. Filters
        out any logits with a probability less than that of the top_kth token.
    top_p (default 0.0): float. Top-p (nucleus) sampling chooses from the
        smallest possible set of tokens whose cumulative probability exceeds
        the probability top_p.

    note: greedy decoding is used if temperature is 0.0, top_k is 0 and
    top_p is 0.0.
    """
    while True:
        model.module.clear_cache()  # clear kv cache between batches
        torch.distributed.barrier(group=mpu.get_model_parallel_group())
        terminate_runs = 0

        if (torch.distributed.is_initialized()
                and torch.distributed.get_rank() == 0):
            os.system("clear")
            raw_text = input("Context prompt >>> ")
            context_tokens = neox_args.tokenizer.tokenize(raw_text)
            if len(context_tokens) == 0:
                context_tokens = [neox_args.tokenizer.eod]
            context_length = len(context_tokens)
            if context_length >= (neox_args.seq_length - 1):
                print_rank_0("\nContext length " + str(context_length) +
                             "\nReached max sequence length!")
                terminate_runs = 1
        else:
            context_tokens = neox_args.tokenizer.tokenize("EMPTY TEXT")
            context_length = len(context_tokens)

        terminate_runs = broadcast_terminate_signal(terminate_runs)
        if terminate_runs == 1:
            return

        for (
            batch_context_tokens,
            batch_token_generation_start_index,
            batch_token_generation_end_index,
            is_done,
        ) in stream_tokens(
            neox_args=neox_args,
            model=model,
            context_tokens=[context_tokens],
            eos_token_id=eos_token_id,
            maximum_tokens=maximum_tokens,
            recompute=recompute,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        ):
            if mpu.get_model_parallel_rank() == 0:
                generated_tokens = (
                    batch_context_tokens[0]
                    .cpu()
                    .numpy()
                    .tolist()[
                        batch_token_generation_start_index[0].item():
                        batch_token_generation_end_index[0].item()
                    ]
                )
                generated_text = neox_args.tokenizer.detokenize(
                    generated_tokens)
                print_rank_0("Generated Text: " + generated_text)

        if (torch.distributed.is_initialized()
                and torch.distributed.get_rank() == 0):
            _ = input("\n<press enter to continue>")
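# Illustrative invocation (assumed surrounding setup): every rank must enter
# the call collectively; only global rank 0 reads stdin and prints output.
# generate_samples_interactive(neox_args=neox_args, model=model,
#                              maximum_tokens=256, top_p=0.9)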
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build train, validation, and test data iterators."""
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')
    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        # Rank, size, and global batch size.
        data_parallel_size = mpu.get_data_parallel_world_size()
        global_batch_size = args.batch_size * data_parallel_size

        # Number of train/valid/test samples.
        train_iters = args.train_iters
        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_iters * global_batch_size,
                                      eval_iters * global_batch_size,
                                      test_iters * global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(
            train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(
            train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(
            train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders.
        train_dataloader = make_data_loader(train_ds)
        valid_dataloader = make_data_loader(valid_ds)
        test_dataloader = make_data_loader(test_ds)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Need to broadcast num_tokens and num_type_tokens.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(flags,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Shift the start iterations.
    if train_dataloader is not None:
        train_dataloader.batch_sampler.start_iter = args.iteration % \
            len(train_dataloader)
        print_rank_0('setting training data start iteration to {}'.format(
            train_dataloader.batch_sampler.start_iter))
    if valid_dataloader is not None:
        start_iter_val = (args.iteration // args.eval_interval) * \
            args.eval_iters
        valid_dataloader.batch_sampler.start_iter = start_iter_val % \
            len(valid_dataloader)
        print_rank_0('setting validation data start iteration to {}'.format(
            valid_dataloader.batch_sampler.start_iter))

    # Build iterators.
    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader)
    else:
        train_data_iterator = None

    if valid_dataloader is not None:
        valid_data_iterator = iter(valid_dataloader)
    else:
        valid_data_iterator = None

    if test_dataloader is not None:
        test_data_iterator = iter(test_dataloader)
    else:
        test_data_iterator = None

    return train_data_iterator, valid_data_iterator, test_data_iterator
def generate_samples_input_from_file(model):
    """Generate completions for prompts read line-by-line from a file and
    write them to an output file."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.get_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('could not find `sample-output-file`, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file
        fname_out = open(sample_output_file, "w+")

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                if input_pos == input_count:
                    raw_text = "stop"

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[len(raw_text):]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                fname_out.write("\nContext:")
                fname_out.write(raw_text)
                fname_out.write("\n\nMegatron-LM:")
                fname_out.write(trim_decode_tokens)
                fname_out.write("\n")

            raw_text = None
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1
def generate_samples_interactive(model, print_frequency=24):
    """Interactively prompt for context text and print streamed completions."""
    args = get_args()
    tokenizer = get_tokenizer()

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            token_stream = get_token_stream(model, [context_tokens])
            for counter, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

                if mpu.get_model_parallel_rank() == 0 and \
                        counter % print_frequency == 0:
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[len(raw_text):]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[len(raw_text):]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            raw_text = None
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1

            if mpu.get_model_parallel_rank() == 0:
                input("\nPress Enter to continue >>>")
def __init__(
    self,
    neox_args,
    attention_mask_func,
    init_method,
    output_layer_init_method,
    layer_number,
    rpe=None,
    rotary=False,
    use_cache=False,
    parallel_output=False,
):
    super().__init__()

    self.fp16 = neox_args.precision == "fp16"
    self.bf16 = neox_args.precision == "bfloat16"
    self.attention_mask_func = attention_mask_func
    self.apply_query_key_layer_scaling = neox_args.apply_query_key_layer_scaling
    self.use_cache = use_cache
    self.attention_softmax_in_fp32 = neox_args.attention_softmax_in_fp32
    if self.apply_query_key_layer_scaling:
        self.attention_softmax_in_fp32 = True
    self.layer_number = layer_number
    # Per attention head and per partition values.
    world_size = mpu.get_model_parallel_world_size()
    self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size,
                                                world_size)
    self.hidden_size_per_attention_head = mpu.divide(
        neox_args.hidden_size, neox_args.num_attention_heads)
    self.num_attention_heads_per_partition = mpu.divide(
        neox_args.num_attention_heads, world_size)
    self.pos_emb = neox_args.pos_emb

    # Strided linear layer.
    self.query_key_value = mpu.ColumnParallelLinear(
        neox_args=neox_args,
        input_size=neox_args.hidden_size,
        output_size=3 * neox_args.hidden_size,
        gather_output=False,
        init_method=init_method,
    )

    coeff = None
    self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
    if self.apply_query_key_layer_scaling:
        coeff = max(1, self.layer_number)
        self.norm_factor *= coeff

    self.rpe = rpe

    if self.pos_emb == "alibi":
        self.alibi_embed = AliBi(
            neox_args.num_attention_heads,
            neox_args.model_parallel_size,
            mpu.get_model_parallel_rank(),
        )

    # TODO: this arg shouldn't need to be passed in - get from neox_args
    if rotary:
        if neox_args.rotary_pct == 1:
            self.rotary_ndims = None
        else:
            assert neox_args.rotary_pct < 1
            self.rotary_ndims = int(self.hidden_size_per_attention_head *
                                    neox_args.rotary_pct)
        dim = (self.rotary_ndims
               if self.rotary_ndims is not None
               else self.hidden_size_per_attention_head)
        self.rotary_emb = RotaryEmbedding(dim,
                                          base=neox_args.rotary_emb_base,
                                          precision=neox_args.params_dtype)
    else:
        self.rotary_emb = None

    self.attention_type = neox_args.attention_config[layer_number]
    self.sparse = self.attention_type != "global"
    if self.sparse:
        self.sparse_attn = configure_sparse_attention(
            neox_args,
            self.attention_type,
            self.num_attention_heads_per_partition,
            mpu=mpu,
        )
    else:
        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            input_in_fp16=self.fp16,
            input_in_bf16=self.bf16,
            fusion_type=get_fusion_type(neox_args),
            mask_func=self.attention_mask_func,
            softmax_in_fp32=self.attention_softmax_in_fp32,
            scale=coeff,
        )

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = nn.Dropout(neox_args.attention_dropout)

    # Output.
    self.dense = mpu.RowParallelLinear(
        neox_args=neox_args,
        input_size=neox_args.hidden_size,
        output_size=neox_args.hidden_size,
        input_is_parallel=True,
        init_method=output_layer_init_method,
        skip_bias_add=True,
        parallel_output=parallel_output,
    )
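# Worked example with assumed numbers: hidden_size=4096 and
# num_attention_heads=32 give hidden_size_per_attention_head = 4096 / 32 = 128.
# With rotary_pct=0.25, rotary_ndims = int(128 * 0.25) = 32, so rotary position
# embeddings rotate only the first 32 dims of each head; the remaining 96 dims
# pass through unrotated.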
def build_train_valid_test_data_iterators(neox_args):
    """Build train, validation, and test data iterators."""

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Ensure only the first/last pipeline stages have data loaders
    if neox_args.is_pipe_parallel:
        is_first_stage = mpu.get_pipe_parallel_rank() == 0
        is_last_stage = (mpu.get_pipe_parallel_rank() ==
                         mpu.get_pipe_parallel_world_size() - 1)
        pipe_load = is_first_stage or is_last_stage
    else:
        pipe_load = True

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0 and pipe_load:
        # Number of train/valid/test samples.
        train_iters = neox_args.train_iters
        eval_iters = (train_iters // neox_args.eval_interval +
                      1) * neox_args.eval_iters
        test_iters = neox_args.eval_iters
        train_val_test_num_samples = [
            train_iters * neox_args.train_batch_size,
            eval_iters * neox_args.train_batch_size,
            test_iters * neox_args.train_batch_size
        ]

        if neox_args.train_data_paths:
            # when individual train / valid / test data paths are provided
            # normalize weight values and get num samples for each dataset
            train_weights, train_num_samples = get_normalized_weights_and_num_samples(
                neox_args.train_data_weights, train_val_test_num_samples[0])
            valid_weights, valid_num_samples = get_normalized_weights_and_num_samples(
                neox_args.valid_data_weights, train_val_test_num_samples[1])
            test_weights, test_num_samples = get_normalized_weights_and_num_samples(
                neox_args.test_data_weights, train_val_test_num_samples[2])

            # build individual datasets
            train_datasets, valid_datasets, test_datasets = build_weighted_datasets(
                neox_args,
                train_num_samples,
                valid_num_samples,
                test_num_samples,
                train_weights,
                valid_weights,
                test_weights,
                build_index_mappings=not neox_args.weight_by_num_documents)

            if neox_args.weight_by_num_documents:
                # gets the number of documents in each datapath
                get_num_docs_list = lambda datasets: [
                    dataset.indexed_dataset.sizes.shape[0]
                    for dataset in datasets
                ]
                train_num_docs, valid_num_docs, test_num_docs = (
                    get_num_docs_list(train_datasets),
                    get_num_docs_list(valid_datasets),
                    get_num_docs_list(test_datasets))

                # builds weights according to alpha + the number of docs
                fn = partial(weights_by_num_docs,
                             alpha=neox_args.weighted_sampler_alpha)
                train_weights, valid_weights, test_weights = (
                    fn(train_num_docs), fn(valid_num_docs), fn(test_num_docs))
                train_weights, train_num_samples = get_normalized_weights_and_num_samples(
                    train_weights, train_val_test_num_samples[0])
                valid_weights, valid_num_samples = get_normalized_weights_and_num_samples(
                    valid_weights, train_val_test_num_samples[1])
                test_weights, test_num_samples = get_normalized_weights_and_num_samples(
                    test_weights, train_val_test_num_samples[2])

                # rebuild datasets weighted according to new weights
                train_datasets, valid_datasets, test_datasets = build_weighted_datasets(
                    neox_args, train_num_samples, valid_num_samples,
                    test_num_samples, train_weights, valid_weights,
                    test_weights)

            if train_datasets:
                train_ds = BlendableDataset(train_datasets, train_weights)
            if valid_datasets:
                valid_ds = BlendableDataset(valid_datasets, valid_weights)
            if test_datasets:
                test_ds = BlendableDataset(test_datasets, test_weights)
        else:
            # when just data_path is provided
            # split dataset into train, valid and test from data_path
            train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
                data_prefix=neox_args.data_path,
                data_impl=neox_args.data_impl,
                splits_string=neox_args.split,
                train_valid_test_num_samples=train_val_test_num_samples,
                seq_length=neox_args.seq_length,
                seed=neox_args.seed,
                skip_warmup=(not neox_args.mmap_warmup))

        # Build dataloaders.
        train_dataloader = make_data_loader(train_ds, neox_args=neox_args)
        valid_dataloader = make_data_loader(valid_ds, neox_args=neox_args)
        test_dataloader = make_data_loader(test_ds, neox_args=neox_args)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and neox_args.train_iters > 0
        do_valid = valid_dataloader is not None and neox_args.eval_iters > 0
        do_test = test_dataloader is not None and neox_args.eval_iters > 0
        # Need to broadcast num_tokens and num_type_tokens.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast num tokens.
    if neox_args.is_pipe_parallel:
        # Only first/last pipeline stages have data loaders, so pipeline
        # parallelism should broadcast globally instead of just the model
        # parallel group.
        torch.distributed.broadcast(flags, src=0)
    else:
        torch.distributed.broadcast(flags,
                                    mpu.get_model_parallel_src_rank(),
                                    group=mpu.get_model_parallel_group())
    neox_args.do_train = flags[0].item()
    neox_args.do_valid = flags[1].item()
    neox_args.do_test = flags[2].item()

    # Shift the start iterations.
    if train_dataloader is not None:
        train_dataloader.batch_sampler.start_iter = (
            neox_args.iteration * neox_args.gradient_accumulation_steps) % \
            len(train_dataloader)
        print_rank_0('setting training data start iteration to {}'.format(
            train_dataloader.batch_sampler.start_iter))
    if valid_dataloader is not None:
        start_iter_val = ((neox_args.iteration *
                           neox_args.gradient_accumulation_steps) //
                          neox_args.eval_interval) * neox_args.eval_iters
        valid_dataloader.batch_sampler.start_iter = start_iter_val % \
            len(valid_dataloader)
        print_rank_0('setting validation data start iteration to {}'.format(
            valid_dataloader.batch_sampler.start_iter))

    # Build iterators.
    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader)
    else:
        train_data_iterator = None

    if valid_dataloader is not None:
        valid_data_iterator = iter(valid_dataloader)
    else:
        valid_data_iterator = None

    if test_dataloader is not None:
        test_data_iterator = iter(test_dataloader)
    else:
        test_data_iterator = None

    return train_data_iterator, valid_data_iterator, test_data_iterator
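# Hedged call-site sketch (assumed training-setup context): the iterators are
# built once, after distributed initialization, and consumed by the training
# loop on every rank.
# train_iter, valid_iter, test_iter = build_train_valid_test_data_iterators(
#     neox_args=neox_args)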
def is_mp_rank_0():
    """True if mp rank == 0"""
    return mpu.get_model_parallel_rank() == 0