def get_token_stream(model, context_tokens):
    args = get_args()
    tokenizer = get_tokenizer()

    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)

    torch.distributed.broadcast(context_length_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    torch.distributed.broadcast(context_tokens_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())

    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids)
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        yield tokens[:, :context_length], lengths
def restore_weights(self, restore_path: str):
    """Restores module/model's weights.
       For model parallel checkpoints the directory structure
       should be restore_path/mp_rank_0X/model_optim_rng.pt

    Args:
        restore_path (str): restore_path should be a file or a directory
            if using model parallel
    """
    self._restore_path = restore_path

    if os.path.isfile(restore_path):
        self._load_checkpoint(restore_path)
    elif os.path.isdir(restore_path):
        # need model parallel groups to restore model parallel checkpoints
        if model_parallel_is_initialized():
            model_parallel_rank = torch.distributed.get_rank(
                group=get_model_parallel_group())
            mp_restore_path = f'{restore_path}/mp_rank_{model_parallel_rank:02d}/model_optim_rng.pt'
            self._load_checkpoint(mp_restore_path)
        else:
            logging.info(
                'torch.distributed not initialized yet. Will not restore model parallel checkpoint'
            )
    else:
        logging.error(
            f'restore_path: {restore_path} must be a file or directory.')
def restore_weights(self, restore_path: str):
    """Restores module/model's weights.
       For model parallel checkpoints the directory structure
       should be restore_path/mp_rank_0X/model_optim_rng.pt

    Args:
        restore_path (str): restore_path should be a file or a directory
            if using model parallel
    """
    self._restore_path = restore_path

    if os.path.isfile(restore_path):
        logging.info(f'restore_path: {restore_path} is a file. '
                     'Assuming no megatron model parallelism')
        state_dict = torch.load(restore_path)
        # to load from Megatron pretrained checkpoint
        if 'model' in state_dict:
            self.language_model.load_state_dict(
                state_dict['model'][self._language_model_key])
        else:
            self.load_state_dict(state_dict)
        logging.info(f"weights restored from {restore_path}")
    elif os.path.isdir(restore_path):
        # need model parallel groups to restore model parallel checkpoints
        if model_parallel_is_initialized():
            model_parallel_rank = torch.distributed.get_rank(
                group=get_model_parallel_group())
            mp_restore_path = f'{restore_path}/mp_rank_{model_parallel_rank:02d}/model_optim_rng.pt'
            logging.info(
                f'Restoring model parallel checkpoint from: {mp_restore_path}')
            state_dict = torch.load(mp_restore_path)
            # to load from Megatron pretrained checkpoint
            if 'model' in state_dict:
                self.language_model.load_state_dict(
                    state_dict['model'][self._language_model_key])
            else:
                self.load_state_dict(state_dict)
        else:
            logging.info('torch.distributed not initialized yet. '
                         'Will not restore model parallel checkpoint')
    else:
        logging.error(
            f'restore_path: {restore_path} must be a file or directory.')
def _unscale_main_grads_and_check_for_nan(self):
    main_grads = []
    # fp32 params from float16 ones.
    for main_group in self.fp32_from_float16_groups:
        for main_param in main_group:
            if main_param.grad is not None:
                main_grads.append(main_param.grad.data)
    # Append fp32 parameters.
    for main_group in self.fp32_from_fp32_groups:
        for main_param in main_group:
            if main_param.grad is not None:
                main_grads.append(main_param.grad.data)
    # Reset found inf.
    self.found_inf.fill_(0.0)
    # Unscale and set found inf/nan
    torch._amp_foreach_non_finite_check_and_unscale_(
        main_grads, self.found_inf, self.grad_scaler.inv_scale)
    # Update across all model parallel instances.
    torch.distributed.all_reduce(self.found_inf,
                                 op=torch.distributed.ReduceOp.MAX,
                                 group=mpu.get_model_parallel_group())
    # Check for nan.
    found_inf_flag = (self.found_inf.item() > 0)
    return found_inf_flag
def init_model_parallel(self, global_rank: int, world_size: int) -> None:
    """ Initializes Megatron-LM model parallel if using model parallelism.

    Args:
        global_rank (int): the global process index.
        world_size (int): the total number of GPUs, num_nodes * num_gpus
    """
    app_state = AppState()

    # we initialize megatron-lm model parallel and data parallel groups
    # after initializing DDP with PTL.
    if app_state.model_parallel_size is not None:
        if torch.distributed.is_initialized():
            mpu.initialize_model_parallel(app_state.model_parallel_size)
            app_state.model_parallel_group = mpu.get_model_parallel_group()
            app_state.data_parallel_group = mpu.get_data_parallel_group()
            app_state.model_parallel_rank = mpu.get_tensor_model_parallel_rank()
            app_state.data_parallel_rank = mpu.get_data_parallel_rank()
            app_state.data_parallel_size = mpu.get_data_parallel_world_size()
            logging.info(f'mp_rank: {app_state.model_parallel_rank}')
            logging.info(f'dp_rank: {app_state.data_parallel_rank}')
            # TODO: get random seed from PTL
            # environment variables are strings, so cast to int
            seed = int(os.environ.get("PL_GLOBAL_SEED", 1234))
            # random seed must be set for megatron model parallel init
            _set_random_seed(seed)
def count_zeros_fp32(parameters):

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    # Filter parameters based on:
    #   - grad should not be none
    #   - parameter should not be shared
    #   - should not be a replica due to tensor model parallelism
    total_num_zeros = 0.0
    for param in parameters:
        grad_not_none = param.grad is not None
        is_not_shared = param_is_not_shared(param)
        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
        if grad_not_none and is_not_shared and is_not_tp_duplicate:
            grad = param.grad.detach()
            num_zeros = grad.numel() - torch.count_nonzero(grad)
            total_num_zeros = num_zeros + total_num_zeros

    # Sum across all model-parallel GPUs.
    torch.distributed.all_reduce(total_num_zeros,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=mpu.get_model_parallel_group())
    total_num_zeros = total_num_zeros.item()

    return total_num_zeros
def broadcast_terminate_signal(terminate_runs: int):
    """Send signal to all workers to terminate if we've finished the process"""
    terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
    torch.distributed.broadcast(terminate_runs_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    return terminate_runs_tensor[0].item()
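# Hedged usage sketch (the helper below is illustrative, not from the source): only the
# model parallel source rank really decides whether to stop; the broadcast inside
# broadcast_terminate_signal overwrites whatever value the other ranks passed in, so
# every rank of the group leaves its generation loop in lockstep.
def generation_should_stop(no_more_prompts: bool) -> bool:
    terminate_runs = 0
    if mpu.get_model_parallel_rank() == 0 and no_more_prompts:
        terminate_runs = 1
    return broadcast_terminate_signal(terminate_runs) == 1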
def init_ddp_connection(self, global_rank: int, world_size: int,
                        is_slurm_managing_tasks: bool = True) -> None:
    """ Override for LightningModule DDP initialization.
        Initializes Megatron-LM model parallel if using model parallelism.

    Args:
        global_rank (int): the global process index.
        world_size (int): the total number of GPUs, num_nodes * num_gpus
        is_slurm_managing_tasks (bool, optional): is the cluster managed by SLURM.
    """
    LightningModule.init_ddp_connection(self, global_rank, world_size,
                                        is_slurm_managing_tasks)

    app_state = AppState()

    # we initialize megatron-lm model parallel and data parallel groups
    # after initializing DDP with PTL.
    if app_state.model_parallel_size is not None:
        if app_state.model_parallel_group is None:
            mpu.initialize_model_parallel(app_state.model_parallel_size)
            app_state.model_parallel_group = mpu.get_model_parallel_group()
            app_state.data_parallel_group = mpu.get_data_parallel_group()
            app_state.model_parallel_rank = torch.distributed.get_rank(
                group=app_state.model_parallel_group)
            app_state.data_parallel_rank = torch.distributed.get_rank(
                group=app_state.data_parallel_group)
            logging.info(f'mp_rank: {app_state.model_parallel_rank}')
            logging.info(f'dp_rank: {app_state.data_parallel_rank}')
def calc_params_l2_norm(model):
    """Calculate the l2 norm of model parameters."""
    args = get_args()
    if not isinstance(model, list):
        model = [model]
    # Remove duplicate params.
    params_data = []
    for model_ in model:
        for param in model_.parameters():
            is_not_shared = param_is_not_shared(param)
            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
            if is_not_shared and is_not_tp_duplicate:
                if args.bf16:
                    params_data.append(param.data.float())
                else:
                    params_data.append(param.data)
    # Calculate norm.
    dummy_overflow_buf = torch.cuda.IntTensor([0])
    norm, _ = multi_tensor_applier(
        amp_C.multi_tensor_l2norm,
        dummy_overflow_buf,
        [params_data],
        False  # no per-parameter norm
    )
    norm_2 = norm * norm
    # Sum across all model-parallel GPUs.
    torch.distributed.all_reduce(norm_2,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=mpu.get_model_parallel_group())
    return norm_2.item() ** 0.5
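# Why squaring before the all-reduce is correct (a standalone sketch with hypothetical
# shards, no distributed setup required): tensor parallelism splits each weight across
# ranks, so every rank computes a partial squared norm and the global norm is the
# square root of the summed partials.
import torch

full = torch.arange(8, dtype=torch.float32)
shards = [full[:4], full[4:]]                    # what two tensor parallel ranks would hold
partial_sq = [s.norm(2) ** 2 for s in shards]    # per-rank norm_2, as in calc_params_l2_norm
assert torch.isclose(sum(partial_sq) ** 0.5, full.norm(2))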
def __init__(self, module):
    from megatron import mpu

    super().__init__(
        module,
        mp_group=mpu.get_model_parallel_group(),
        dp_group=mpu.get_data_parallel_group(),
    )
def flops_calculator(model, args, iteration_time):
    gpus_per_model = torch.distributed.get_world_size(
        group=mpu.get_model_parallel_group())

    approx_parameters_in_billions = get_parameters_in_billions(model)

    giga_flops_per_model_per_train_step = (approx_parameters_in_billions *
                                           args.batch_size *
                                           args.seq_length * 2.0 * 4.0)

    effective_tera_flops_per_gpu = giga_flops_per_model_per_train_step / (
        iteration_time * 1000.0 * gpus_per_model)

    print_rank_0(
        f"Effective Tera Flops per GPU: {round(effective_tera_flops_per_gpu, 2)} "
        f"and total parameters {round(approx_parameters_in_billions, 3)} B")
def get_parameters_in_billions(model):
    gpus_per_model = torch.distributed.get_world_size(
        group=mpu.get_model_parallel_group())

    approx_parameters_in_billions = sum([
        p.ds_numel if hasattr(p, 'ds_id') else p.numel()
        for p in model.parameters()
    ]) * gpus_per_model / 1000000000.0

    return approx_parameters_in_billions
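# Rough arithmetic behind the gpus_per_model factor (a hedged sketch, numbers are made
# up): each tensor parallel rank stores roughly 1/mp_size of the weights, so the local
# parameter count times the model parallel world size approximates the full model size.
local_params = 650_000_000            # hypothetical per-rank parameter count
mp_size = 2                           # hypothetical model parallel world size
print(local_params * mp_size / 1e9)   # -> 1.3 billion, what get_parameters_in_billions reports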
def has_overflow(self, params):
    overflow = self.has_overflow_serial(params)
    # Since each model parallel GPU carries only part of the model,
    # make sure overflow flag is synced across all the model parallel GPUs
    overflow_gpu = torch.cuda.ByteTensor([overflow])
    torch.distributed.all_reduce(overflow_gpu,
                                 op=torch.distributed.ReduceOp.MAX,
                                 group=mpu.get_model_parallel_group())
    overflow = overflow_gpu[0].item()
    return bool(overflow)
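# Hedged sketch of how the synced overflow flag is typically consumed (the loss-scaler
# and optimizer calls below are illustrative, not from the source): because the flag is
# MAX-reduced over the model parallel group, every rank makes the same skip/step
# decision and the sharded weights stay consistent.
if loss_scaler.has_overflow(list(model.parameters())):
    loss_scaler.update_scale(overflow=True)   # hypothetical: shrink the loss scale, skip the step
else:
    optimizer.step()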
def fmoefy(
    model,
    num_experts=None,
    distributed_experts=True,
    hidden_hidden_size=None,
    top_k=None,
):
    r"""
    Replace MLP layers in a transformer-based model in Megatron by MoE.

    * `model` should be a standard Megatron model that has
      `model.language_model.transformer.layers` as transformer layers, which is an
      array of transformer blocks that contain an `mlp` member.
    * `distributed_experts` is set to True if different experts are located in
      different workers. Otherwise, the experts on the workers are identical, and
      they are trained in data-parallel mode. This can be useful when testing on
      small models that do not require high training throughput or large parameter
      capacity.

    Note that pipeline parallel is not supported yet. When distributed experts are
    enabled, their communicator should be Megatron's
    tensor_model_parallel_comm x data_parallel_comm, which is not created.
    """
    from megatron import get_args
    from megatron import mpu

    args = get_args()

    if num_experts is not None:
        args.num_experts = num_experts
    assert (
        "num_experts" in args
    ), "num_experts should be specified in arguments or fmoefy function"

    if hidden_hidden_size is not None:
        args.hidden_hidden_size = hidden_hidden_size
    elif not hasattr(args, "hidden_hidden_size"):
        args.hidden_hidden_size = args.hidden_size * 4

    if top_k is not None:
        args.top_k = top_k
    elif not hasattr(args, "top_k"):
        args.top_k = 2

    # Set distributed_experts to None to use the default setting in args
    if distributed_experts is not None:
        args.distributed_experts = distributed_experts

    for idx, l in enumerate(model.language_model.transformer.layers):
        l.mlp = MegatronMLP(args, mpu.get_model_parallel_group(), idx)

    # initialize gate hook
    global num_layers, balance_dict
    num_layers = len(model.language_model.transformer.layers)
    reset_gate_hook()

    return model
def generate_samples_interactive(
    neox_args,
    model,
    maximum_tokens: int = 64,
    eos_token_id: int = None,
    recompute: bool = False,
    temperature: float = 0.0,
    top_k: int = 0,
    top_p: float = 0.0,
):
    """
    Generates samples interactively from prompts entered on rank 0 and prints the completions.

    neox_args: NeoXArgs.
    model: a Megatron model
    maximum_tokens: maximum number of tokens to be generated
    eos_token_id: end of text token at which completion is terminated, even if max_tokens count has not been reached
    recompute: flag indicating whether a cache is used for already forwarded tokens (true) or whether all tokens are recomputed at every iteration (false)
    temperature (default 0.0): exponential scaling output distribution ("higher == more risk")
    top_k (default 0): integer -> integer between 0 and the model's vocab size. Filters out any logits with a probability less than that of the top_kth token.
    top_p (default 0.0): float -> Top-p (nucleus) sampling chooses from the smallest possible set of tokens whose cumulative probability exceeds the probability top_p.

    note: greedy decoding is used if temperature is 0.0, top_k is 0 and top_p is 0.0

    The generated text is printed on the first model parallel rank; nothing is returned.
    """
    while True:
        model.module.clear_cache()  # clear kv cache between batches
        torch.distributed.barrier(group=mpu.get_model_parallel_group())
        terminate_runs = 0

        if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
            os.system("clear")
            raw_text = input("Context prompt >>> ")
            context_tokens = neox_args.tokenizer.tokenize(raw_text)
            if len(context_tokens) == 0:
                context_tokens = [neox_args.tokenizer.eod]
            context_length = len(context_tokens)
            if context_length >= (neox_args.seq_length - 1):
                print_rank_0("\nContext length " + str(context_length) +
                             "\nReached max sequence length!")
                terminate_runs = 1
        else:
            context_tokens = neox_args.tokenizer.tokenize("EMPTY TEXT")
            context_length = len(context_tokens)

        terminate_runs = broadcast_terminate_signal(terminate_runs)
        if terminate_runs == 1:
            return

        for (
            batch_context_tokens,
            batch_token_generation_start_index,
            batch_token_generation_end_index,
            is_done,
        ) in stream_tokens(
            neox_args=neox_args,
            model=model,
            context_tokens=[context_tokens],
            eos_token_id=eos_token_id,
            maximum_tokens=maximum_tokens,
            recompute=recompute,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        ):
            if mpu.get_model_parallel_rank() == 0:
                generated_tokens = (
                    batch_context_tokens[0].cpu().numpy().tolist()
                    [batch_token_generation_start_index[0].item():
                     batch_token_generation_end_index[0].item()]
                )
                generated_text = neox_args.tokenizer.detokenize(generated_tokens)
                print_rank_0("Generated Text: " + generated_text)

        if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
            _ = input("\n<press enter to continue>")
def stream_tokens(
    neox_args,
    model,
    context_tokens: List[List[int]],
    eos_token_id: int = None,
    maximum_tokens: int = None,
    recompute: bool = False,
    temperature: float = 0.0,
    top_k: int = 0,
    top_p: float = 0.0,
    stop_tokens=None,
):
    """
    iterator producing text completions

    neox_args: NeoXArgs.
    model: a Megatron model.
    context_tokens: the prompt to complete; unpadded list of lists of token ids
    context_lengths: lengths of context tokens of dimension [batch]; the context length records for each batch item how many non-padded tokens are provided
    eos_token_id: end of text token at which completion is terminated, even if max_tokens count has not been reached
    attention_mask: attention mask for megatron model.
    position_ids: position ids for positional encoding.
    maximum_tokens: maximum number of tokens to be generated; careful! if a batch input is provided maximum_tokens specifies the maximum number of forwards.
                    longer batch items get less generated tokens.
    recompute: flag indicating whether a cache is used for already forwarded tokens (true) or whether all tokens are recomputed at every iteration (false)
    temperature (default 0.0): exponential scaling output distribution ("higher == more risk")
    top_k (default 0): integer -> integer between 0 and the model's vocab size. Filters out any logits with a probability less than that of the top_kth token.
    top_p (default 0.0): float -> Top-p (nucleus) sampling chooses from the smallest possible set of tokens whose cumulative probability exceeds the probability top_p.

    note: greedy decoding is used if temperature is 0.0, top_k is 0 and top_p is 0.0

    yields: (
        tokens (completions from model),
        token_generation_start_index (token index per batch item for the first generated token),
        token_generation_end_index (token index per batch item for the last generated token),
        is_done (flag for each batch item indicating whether an eod token was generated)
    )

        * each iteration adds a generated token to the context_tokens
        * output contains both context_tokens from input and generated tokens
        * if batch items have different lengths, the iterator will start at the first completion and return the unchanged input context token otherwise
    """
    model.eval()

    # pad batch in order to allow conversion to tensor
    context_tokens, context_lengths = pad_batch(
        copy.deepcopy(context_tokens),
        pad_id=neox_args.tokenizer.eod,
        pad_len=neox_args.seq_length,
    )

    # convert to tensor and broadcast
    context_tokens = torch.cuda.LongTensor(context_tokens)
    if stop_tokens:
        stop_tokens = torch.cuda.LongTensor(stop_tokens)
        if stop_tokens.ndim == 1:
            stop_tokens = stop_tokens.unsqueeze(0)

    # Make sure context tokens + start tokens are the same across all ranks
    token_generation_start_index = torch.cuda.LongTensor(context_lengths)
    torch.distributed.broadcast(
        context_tokens,
        mpu.get_model_parallel_src_rank(),
        group=mpu.get_model_parallel_group(),
    )
    torch.distributed.broadcast(
        token_generation_start_index,
        mpu.get_model_parallel_src_rank(),
        group=mpu.get_model_parallel_group(),
    )

    # get attention mask / position ids
    context_tokens, attention_mask, position_ids = get_batch(neox_args, context_tokens)

    # set variables
    eos_token_id = eos_token_id or neox_args.tokenizer.eod
    maximum_tokens = maximum_tokens or (
        neox_args.seq_length - token_generation_start_index.max().item() - 1)
    batch_size = context_tokens.size(0)

    # get the context_index at which generation is to start
    # we start generation at the position where the smallest context ends
    token_index_to_generate = token_generation_start_index.min().item()
    first_token_index_to_generate = token_index_to_generate
    last_token_index_to_generate = min(
        neox_args.seq_length - 1,  # never generate more than the model's sequence length
        token_index_to_generate + maximum_tokens - 1,
    )

    with torch.no_grad():
        # initialize generation variables
        state_is_done = torch.zeros([batch_size]).byte().cuda()
        token_generation_end_index = torch.ones([batch_size]).long().cuda() * (-1)

        while token_index_to_generate <= last_token_index_to_generate:
            if recompute:  # recompute all tokens
                model_inputs = (
                    context_tokens,
                    position_ids,
                    attention_mask,
                )
                logits = forward_model(model, model_inputs, neox_args.is_pipe_parallel)
                if logits is not None:  # if pipe parallel, not all ranks return logits
                    generated_token_logits = logits[
                        :, token_index_to_generate - 1, :
                    ]  # [bs, seq, vocab_size] -> [bs, vocab_size]
            else:  # use kv cache
                if token_index_to_generate == first_token_index_to_generate:
                    tokens_to_use = context_tokens[:, :token_index_to_generate]
                    positions_to_use = position_ids[:, :token_index_to_generate]
                else:
                    tokens_to_use = context_tokens[:, token_index_to_generate - 1].view(
                        batch_size, -1)
                    positions_to_use = position_ids[:, token_index_to_generate - 1].view(
                        batch_size, -1)

                model_inputs = (
                    tokens_to_use,  # input_ids
                    positions_to_use,  # position_ids
                    attention_mask,  # attention_mask
                )

                logits = forward_model(model, model_inputs, neox_args.is_pipe_parallel)
                if logits is not None:  # if pipe parallel, not all ranks return logits
                    generated_token_logits = (
                        logits[:, -1].view(batch_size, -1).contiguous()
                    )  # [bs, seq, vocab_size] -> [bs, vocab_size]

            if logits is not None:
                # sample token id of the to be generated token
                if temperature == 0.0 and top_k == 0 and top_p == 0.0:
                    generated_tokens = torch.argmax(generated_token_logits,
                                                    dim=-1).view(-1)
                else:
                    generated_token_logits = generated_token_logits.float()
                    if temperature > 0.0:
                        generated_token_logits /= temperature
                    generated_token_logits = filter_logits(
                        generated_token_logits, top_k=top_k, top_p=top_p)
                    next_token_log_probs = F.softmax(generated_token_logits, dim=-1)
                    generated_tokens = torch.multinomial(
                        next_token_log_probs, num_samples=1).view(-1)

            if neox_args.is_pipe_parallel:
                # broadcast generated tokens to pipe parallel group
                src_rank = model.grid.stage_to_global(model.num_stages - 1)
                generated_tokens = (
                    generated_tokens
                    if logits is not None
                    else torch.zeros(batch_size, dtype=torch.long).cuda()
                )
                torch.distributed.broadcast(
                    tensor=generated_tokens,
                    src=src_rank,
                    group=mpu.get_pipe_parallel_group(),
                )

            # determine if state has started for each batch item
            state_started = (
                token_generation_start_index <= token_index_to_generate
            )  # check which batch items have been started

            # switch out padding tokens for generated tokens
            context_tokens[:, token_index_to_generate] = switch(
                context_tokens[:, token_index_to_generate].view(-1),
                generated_tokens,
                state_started,
            )

            # determine if state has finished for each batch item
            state_done = (
                (generated_tokens == eos_token_id).byte() & state_started.byte()
            )  # check which batch items produce an eos_token in the current iteration
            state_just_finished = (state_done & ~state_is_done).bool()
            state_is_done = state_is_done | state_done
            stop_tokens_produced = torch.zeros_like(state_is_done)
            for batch_idx, ctx in enumerate(context_tokens):
                stop_tokens_produced[batch_idx] = stop_tokens_in_completion(
                    stop_tokens, context_tokens, batch_idx, token_index_to_generate)
            state_is_done = state_is_done | stop_tokens_produced

            token_generation_end_index[
                (state_started.byte() & ~state_is_done).bool()
            ] = token_index_to_generate

            token_index_to_generate += 1

            yield context_tokens, token_generation_start_index, token_generation_end_index, state_is_done.bool()
            if torch.all(state_is_done):
                break
def stream_tokens(neox_args,
                  model,
                  context_tokens: List[List[int]],
                  eos_token_id: int = None,
                  maximum_tokens: int = None,
                  recompute: bool = False,
                  temperature: float = 0.0,
                  top_k: int = 0,
                  top_p: float = 0.0,
                  stop_tokens=None):
    """
    iterator producing text completions

    neox_args: NeoXArgs with tokenizer, reset_position_ids, reset_attention_mask and eod_mask_loss
    model: a Megatron model.
    context_tokens: the prompt to complete; unpadded list of lists of token ids
    context_lengths: lengths of context tokens of dimension [batch]; the context length records for each batch item how many non-padded tokens are provided
    attention_mask: attention mask for megatron model.
    position_ids: position ids for positional encoding.
    eos_token_id: end of text token at which completion is terminated, even if max_tokens count has not been reached
    maximum_tokens: maximum number of tokens to be generated; careful! if a batch input is provided maximum_tokens specifies the maximum number of forwards.
                    longer batch items get less generated tokens.
    recompute: flag indicating whether a cache is used for already forwarded tokens (true) or whether all tokens are recomputed at every iteration (false)
    temperature (default 0.0): exponential scaling output distribution ("higher == more risk")
    top_k (default 0): integer -> integer between 0 and the model's vocab size. Filters out any logits with a probability less than that of the top_kth token.
    top_p (default 0.0): float -> Top-p (nucleus) sampling chooses from the smallest possible set of tokens whose cumulative probability exceeds the probability top_p.

    note: greedy decoding is used if temperature is 0.0, top_k is 0 and top_p is 0.0

    yields: (
        tokens (completions from model),
        token_generation_start_index (token index per batch item for the first generated token),
        token_generation_end_index (token index per batch item for the last generated token),
        logits (logits which are so far computed, zeros otherwise),
        is_done (flag for each batch item indicating whether an eod token was generated)
    )

        * each iteration adds a generated token to the context_tokens
        * output contains both context_tokens from input and generated tokens
        * if batch items have different lengths, the iterator will start at the first completion and return the unchanged input context token otherwise
    """
    model.eval()

    # pad batch in order to allow conversion to tensor
    context_tokens, context_lengths = pad_batch(copy.deepcopy(context_tokens),
                                                pad_id=neox_args.tokenizer.eod,
                                                pad_len=neox_args.seq_length)

    # convert to tensor and broadcast
    context_tokens = torch.cuda.LongTensor(context_tokens)
    if stop_tokens:
        stop_tokens = torch.cuda.LongTensor(stop_tokens)
        if stop_tokens.ndim == 1:
            stop_tokens = stop_tokens.unsqueeze(0)

    token_generation_start_index = torch.cuda.LongTensor(context_lengths)
    torch.distributed.broadcast(context_tokens,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    torch.distributed.broadcast(token_generation_start_index,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())

    # produce batch relevant attention_mask and position_ids
    context_tokens, attention_mask, position_ids = get_batch(neox_args, context_tokens)

    # determine the smallest context length at which first output is produced
    context_length = token_generation_start_index.min().item()

    # set variables
    eos_token_id = eos_token_id or neox_args.tokenizer.eod
    maximum_tokens = maximum_tokens or (
        neox_args.seq_length - token_generation_start_index.max().item() - 1)
    batch_size = context_tokens.size(0)

    # get the context_index at which generation is to start
    # we start generation at the position where the smallest context ends
    token_index_to_generate = token_generation_start_index.min().item()
    first_token_index_to_generate = token_index_to_generate
    last_token_index_to_generate = min(
        neox_args.seq_length - 1,  # never generate more than the model's sequence length
        token_index_to_generate + maximum_tokens - 1)

    all_logits = torch.zeros(
        (batch_size, neox_args.seq_length, neox_args.padded_vocab_size))

    with torch.no_grad():
        # initialize generation variables
        state_is_done = torch.zeros([batch_size]).byte().cuda()
        layer_past = torch.Tensor().cuda()
        token_generation_end_index = torch.ones([batch_size]).long().cuda() * (-1)

        while token_index_to_generate <= last_token_index_to_generate:
            if recompute:
                # recompute is needed for sparse attention at the moment
                # because we can only forward multiples of the block size
                # TODO The full padded context_tokens would not need to be forwarded, adjust to multiples of block size
                # we need to use neox_args instead of kwargs here because deepspeed :|
                model_inputs = (
                    context_tokens,
                    position_ids,
                    attention_mask,
                    torch.Tensor(),
                )
                logits, _ = forward_model(neox_args, model, model_inputs)
                generated_token_logits = logits[:, token_index_to_generate - 1, :]
                all_logits = logits
            else:
                # not choosing recompute assumes that any number of tokens can be forwarded
                # this is not the case for sparse attention
                if token_index_to_generate == first_token_index_to_generate:
                    tokens_to_use = context_tokens[:, :token_index_to_generate]
                    positions_to_use = position_ids[:, :token_index_to_generate]
                else:
                    tokens_to_use = context_tokens[:, token_index_to_generate - 1].view(
                        batch_size, -1)
                    positions_to_use = position_ids[:, token_index_to_generate - 1].view(
                        batch_size, -1)

                # we have to use neox_args instead of kwargs here because deepspeed :|
                model_inputs = (
                    tokens_to_use,  # input_ids
                    positions_to_use,  # position_ids
                    attention_mask,  # attention_mask
                    layer_past,  # layer_past
                )
                logits, layer_past = forward_model(neox_args, model, model_inputs)

                # TODO: we are replicating computation across all machines here, which is
                # really unnecessary; we should probably just do it on one and then
                # communicate the results?
                generated_token_logits = logits[:, -1].view(batch_size, -1).contiguous()
                if token_index_to_generate == first_token_index_to_generate:
                    all_logits[:, :token_index_to_generate, :] = \
                        logits[:, :token_index_to_generate, :]
                else:
                    # only one token is computed
                    all_logits[:, token_index_to_generate - 1, :] = logits[:, 0, :]

            # sample token id of the to be generated token
            if temperature == 0.0 and top_k == 0 and top_p == 0.0:
                generated_tokens = torch.argmax(generated_token_logits,
                                                dim=-1).view(-1)
            else:
                generated_token_logits = generated_token_logits.float()
                if temperature > 0.0:
                    generated_token_logits /= temperature
                generated_token_logits = filter_logits(generated_token_logits,
                                                       top_k=top_k,
                                                       top_p=top_p)
                next_token_log_probs = F.softmax(generated_token_logits, dim=-1)
                generated_tokens = torch.multinomial(next_token_log_probs,
                                                     num_samples=1).view(-1)

            # determine if state has started for each batch item
            state_started = token_generation_start_index <= token_index_to_generate  # check which batch items have been started

            # switch out only padding tokens (the batch items that have been started)
            context_tokens[:, token_index_to_generate] = switch(
                context_tokens[:, token_index_to_generate].view(-1),
                generated_tokens, state_started)

            # determine if state has finished for each batch item
            state_done = (generated_tokens == eos_token_id).byte() & state_started.byte()  # check which batch items produce an eos_token in the current iteration
            state_just_finished = (state_done & ~state_is_done).bool()
            state_is_done = state_is_done | state_done
            stop_tokens_produced = torch.zeros_like(state_is_done)
            for batch_idx, ctx in enumerate(context_tokens):
                stop_tokens_produced[batch_idx] = stop_tokens_in_completion(
                    stop_tokens, context_tokens, batch_idx, token_index_to_generate)
            state_is_done = state_is_done | stop_tokens_produced

            token_generation_end_index[
                (state_started.byte() & ~state_is_done).bool()] = token_index_to_generate

            token_index_to_generate += 1

            yield context_tokens, token_generation_start_index, token_generation_end_index, all_logits, state_is_done.bool()
            if torch.all(state_is_done):
                break
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build train, validation, and test data iterators."""
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')
    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        # Rank, size, and global batch size.
        data_parallel_size = mpu.get_data_parallel_world_size()
        global_batch_size = args.batch_size * data_parallel_size

        # Number of train/valid/test samples.
        train_iters = args.train_iters
        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_iters * global_batch_size,
                                      eval_iters * global_batch_size,
                                      test_iters * global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloders.
        train_dataloader = make_data_loader(train_ds)
        valid_dataloader = make_data_loader(valid_ds)
        test_dataloader = make_data_loader(test_ds)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Need to broadcast num_tokens and num_type_tokens.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(flags,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Shift the start iterations.
    if train_dataloader is not None:
        train_dataloader.batch_sampler.start_iter = args.iteration % \
            len(train_dataloader)
        print_rank_0('setting training data start iteration to {}'.format(
            train_dataloader.batch_sampler.start_iter))
    if valid_dataloader is not None:
        start_iter_val = (args.iteration // args.eval_interval) * \
            args.eval_iters
        valid_dataloader.batch_sampler.start_iter = start_iter_val % \
            len(valid_dataloader)
        print_rank_0('setting validation data start iteration to {}'.format(
            valid_dataloader.batch_sampler.start_iter))

    # Build iterators.
    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader)
    else:
        train_data_iterator = None

    if valid_dataloader is not None:
        valid_data_iterator = iter(valid_dataloader)
    else:
        valid_data_iterator = None

    if test_dataloader is not None:
        test_data_iterator = iter(test_dataloader)
    else:
        test_data_iterator = None

    return train_data_iterator, valid_data_iterator, test_data_iterator
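# Hedged usage sketch (the provider and training-step names are illustrative, not from
# the source): the pretraining script typically passes its dataset provider in and then
# gates each phase on the do_train/do_valid/do_test flags that were broadcast above.
# Note that only rank 0 of each model parallel group gets real iterators; the other
# ranks receive `None` and are expected to obtain their batches via broadcast elsewhere.
train_it, valid_it, test_it = build_train_valid_test_data_iterators(
    train_valid_test_datasets_provider)   # hypothetical provider from the pretraining script
if get_args().do_train and train_it is not None:
    first_batch = next(train_it)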
def generate_samples_input_from_file(model):

    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.get_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('could not find `sample-output-file`, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file
        fname_out = open(sample_output_file, "w+")

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                if input_pos == input_count:
                    raw_text = "stop"

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[len(raw_text):]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                fname_out.write("\nContext:")
                fname_out.write(raw_text)
                fname_out.write("\n\nMegatron-LM:")
                fname_out.write(trim_decode_tokens)
                fname_out.write("\n")

            raw_text = None

            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
    """Clips gradient norm of an iterable of parameters whose gradients
       are in fp32.

    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
    added functionality to handle model parallel parameters. Note that
    the gradients are modified in place.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.

    Returns:
        Total norm of the parameters (viewed as a single vector).
    """

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    # Filter parameters based on:
    #   - grad should not be none
    #   - parameter should not be shared
    #   - should not be a replica due to tensor model parallelism
    grads = []
    grads_for_norm = []
    for param in parameters:
        grad_not_none = param.grad is not None
        is_not_shared = param_is_not_shared(param)
        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
        if grad_not_none:
            grad = param.grad.detach()
            # Make sure the grads are in fp32
            assert param.grad.type() == 'torch.cuda.FloatTensor'
            grads.append(grad)
        if grad_not_none and is_not_shared and is_not_tp_duplicate:
            grads_for_norm.append(grad)

    # Norm parameters.
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    total_norm = 0.0

    # Calculate norm.
    if norm_type == inf:
        total_norm = max(grad.abs().max() for grad in grads_for_norm)
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        # Take max across all model-parallel GPUs.
        torch.distributed.all_reduce(total_norm_cuda,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=mpu.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()
    else:
        if norm_type == 2.0:
            dummy_overflow_buf = torch.cuda.IntTensor([0])
            # Use apex's multi-tensor applier for efficiency reasons.
            # Multi-tensor applier takes a function and a list of list
            # and performs the operation on that list all in one kernel.
            grad_norm, _ = multi_tensor_applier(
                amp_C.multi_tensor_l2norm,
                dummy_overflow_buf,
                [grads_for_norm],
                False  # no per-parameter norm
            )
            # Since we will be summing across data parallel groups,
            # we need the pow(norm-type).
            total_norm = grad_norm ** norm_type
        else:
            for grad in grads_for_norm:
                grad_norm = torch.norm(grad, norm_type)
                total_norm += grad_norm ** norm_type

        # Sum across all model-parallel GPUs.
        torch.distributed.all_reduce(total_norm,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=mpu.get_model_parallel_group())
        total_norm = total_norm.item() ** (1.0 / norm_type)

    # Scale.
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        dummy_overflow_buf = torch.cuda.IntTensor([0])
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             dummy_overflow_buf,
                             [grads, grads],
                             clip_coeff)

    return total_norm
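# Minimal usage sketch, assuming fp32 gradients that have already been reduced across
# data parallel ranks (the function above only synchronizes the norm across the model
# parallel group); `model` stands in for a Megatron module on the current rank.
params = [p for p in model.parameters() if p.grad is not None]
grad_norm = clip_grad_norm_fp32(params, max_norm=1.0, norm_type=2)
print_rank_0(f'grad norm before clipping: {grad_norm}')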
def generate_samples_interactive(model, print_frequency=24):

    args = get_args()
    tokenizer = get_tokenizer()

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            token_stream = get_token_stream(model, [context_tokens])

            for counter, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

                if mpu.get_model_parallel_rank() == 0 and \
                   counter % print_frequency == 0:
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[len(raw_text):]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[len(raw_text):]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            raw_text = None

            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1

            if mpu.get_model_parallel_rank() == 0:
                input("\nPress any key to continue >>>")
def generate_samples_input_from_file(model):

    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file
        fname_out = open(sample_output_file, "w+")

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                if input_pos == input_count:
                    raw_text = "stop"
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(input_info_tensor,
                                         group=mpu.get_model_parallel_group())
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallel we send context tokens to other stages
            # so they get the lengths correct
            if mpu.get_tensor_model_parallel_rank() == 0 \
               and args.pipeline_model_parallel_size > 1:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                else:
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.empty(context_length,
                                                        dtype=torch.int64,
                                                        device=torch.device("cuda"))
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass

            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)

                    fname_out.write("\nContext:")
                    fname_out.write(raw_text)

                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[raw_text_len:]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                    fname_out.write("\n\nMegatron-LM:")
                    fname_out.write(trim_decode_tokens)
                    fname_out.write("\n")

            raw_text = None
            context_count += 1
def generate_samples_interactive(model, print_frequency=24):

    args = get_args()
    tokenizer = get_tokenizer()

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                os.system('clear')
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(input_info_tensor,
                                         group=mpu.get_model_parallel_group())
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallel we send context tokens to other stages
            # so they get the lengths correct
            if mpu.get_tensor_model_parallel_rank() == 0 \
               and args.pipeline_model_parallel_size > 1:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                else:
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.empty(context_length,
                                                        dtype=torch.int64,
                                                        device=torch.device("cuda"))
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])

            for counter, decode_tokens in enumerate(token_stream):
                if counter % print_frequency != 0 \
                   or mpu.get_tensor_model_parallel_rank() != 0 \
                   or not mpu.is_pipeline_first_stage():
                    continue

                os.system('clear')
                print("\nContext:", raw_text, flush=True)

                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[raw_text_len:]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)

                if not isinstance(decode_tokens, list):
                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[raw_text_len:]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                input("\nPress Enter to continue >>>")

            raw_text = None
            context_count += 1
def build_train_valid_test_data_iterators(neox_args):
    """Build train, validation, and test data iterators."""

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Ensure only the first/last pipeline stages have data loaders
    if neox_args.is_pipe_parallel:
        is_first_stage = mpu.get_pipe_parallel_rank() == 0
        is_last_stage = mpu.get_pipe_parallel_rank() == mpu.get_pipe_parallel_world_size() - 1
        pipe_load = is_first_stage or is_last_stage
    else:
        pipe_load = True

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0 and pipe_load:
        # Number of train/valid/test samples.
        train_iters = neox_args.train_iters
        eval_iters = (train_iters // neox_args.eval_interval + 1) * neox_args.eval_iters
        test_iters = neox_args.eval_iters
        train_val_test_num_samples = [
            train_iters * neox_args.train_batch_size,
            eval_iters * neox_args.train_batch_size,
            test_iters * neox_args.train_batch_size
        ]

        if neox_args.train_data_paths:
            # when individual train / valid / test data paths are provided
            # normalize weight values and get num samples for each dataset
            train_weights, train_num_samples = get_normalized_weights_and_num_samples(
                neox_args.train_data_weights, train_val_test_num_samples[0])
            valid_weights, valid_num_samples = get_normalized_weights_and_num_samples(
                neox_args.valid_data_weights, train_val_test_num_samples[1])
            test_weights, test_num_samples = get_normalized_weights_and_num_samples(
                neox_args.test_data_weights, train_val_test_num_samples[2])

            # build individual datasets
            train_datasets, valid_datasets, test_datasets = build_weighted_datasets(
                neox_args, train_num_samples, valid_num_samples, test_num_samples,
                train_weights, valid_weights, test_weights,
                build_index_mappings=not neox_args.weight_by_num_documents)

            if neox_args.weight_by_num_documents:
                # gets the number of documents in each datapath
                get_num_docs_list = lambda datasets: [
                    dataset.indexed_dataset.sizes.shape[0] for dataset in datasets
                ]
                train_num_docs, valid_num_docs, test_num_docs = (
                    get_num_docs_list(train_datasets),
                    get_num_docs_list(valid_datasets),
                    get_num_docs_list(test_datasets))

                # builds weights according to alpha + the number of docs
                fn = partial(weights_by_num_docs,
                             alpha=neox_args.weighted_sampler_alpha)
                train_weights, valid_weights, test_weights = (
                    fn(train_num_docs), fn(valid_num_docs), fn(test_num_docs))
                train_weights, train_num_samples = get_normalized_weights_and_num_samples(
                    train_weights, train_val_test_num_samples[0])
                valid_weights, valid_num_samples = get_normalized_weights_and_num_samples(
                    valid_weights, train_val_test_num_samples[1])
                test_weights, test_num_samples = get_normalized_weights_and_num_samples(
                    test_weights, train_val_test_num_samples[2])

                # rebuild datasets weighted according to new weights
                train_datasets, valid_datasets, test_datasets = build_weighted_datasets(
                    neox_args, train_num_samples, valid_num_samples, test_num_samples,
                    train_weights, valid_weights, test_weights)

            if train_datasets:
                train_ds = BlendableDataset(train_datasets, train_weights)
            if valid_datasets:
                valid_ds = BlendableDataset(valid_datasets, valid_weights)
            if test_datasets:
                test_ds = BlendableDataset(test_datasets, test_weights)
        else:
            # when just data_path is provided
            # split dataset into train, valid and test from data_path
            train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
                data_prefix=neox_args.data_path,
                data_impl=neox_args.data_impl,
                splits_string=neox_args.split,
                train_valid_test_num_samples=train_val_test_num_samples,
                seq_length=neox_args.seq_length,
                seed=neox_args.seed,
                skip_warmup=(not neox_args.mmap_warmup))

        # Build dataloders.
        train_dataloader = make_data_loader(train_ds, neox_args=neox_args)
        valid_dataloader = make_data_loader(valid_ds, neox_args=neox_args)
        test_dataloader = make_data_loader(test_ds, neox_args=neox_args)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and neox_args.train_iters > 0
        do_valid = valid_dataloader is not None and neox_args.eval_iters > 0
        do_test = test_dataloader is not None and neox_args.eval_iters > 0
        # Need to broadcast num_tokens and num_type_tokens.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast num tokens.
    if neox_args.is_pipe_parallel:
        # Only first/last pipeline stages have data loaders, so pipeline parallelism should
        # broadcast globally instead of just the model parallel group.
        torch.distributed.broadcast(flags, src=0)
    else:
        torch.distributed.broadcast(flags,
                                    mpu.get_model_parallel_src_rank(),
                                    group=mpu.get_model_parallel_group())
    neox_args.do_train = flags[0].item()
    neox_args.do_valid = flags[1].item()
    neox_args.do_test = flags[2].item()

    # Shift the start iterations.
    if train_dataloader is not None:
        train_dataloader.batch_sampler.start_iter = (
            neox_args.iteration * neox_args.gradient_accumulation_steps
        ) % len(train_dataloader)
        print_rank_0('setting training data start iteration to {}'.format(
            train_dataloader.batch_sampler.start_iter))
    if valid_dataloader is not None:
        start_iter_val = (
            (neox_args.iteration * neox_args.gradient_accumulation_steps) //
            neox_args.eval_interval) * neox_args.eval_iters
        valid_dataloader.batch_sampler.start_iter = start_iter_val % \
            len(valid_dataloader)
        print_rank_0('setting validation data start iteration to {}'.format(
            valid_dataloader.batch_sampler.start_iter))

    # Build iterators.
    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader)
    else:
        train_data_iterator = None

    if valid_dataloader is not None:
        valid_data_iterator = iter(valid_dataloader)
    else:
        valid_data_iterator = None

    if test_dataloader is not None:
        test_data_iterator = iter(test_dataloader)
    else:
        test_data_iterator = None

    return train_data_iterator, valid_data_iterator, test_data_iterator
def restore_weights(self, restore_path: str):
    """Restores module/model's weights.
       For model parallel checkpoints the directory structure
       should be restore_path/mp_rank_0X/model_optim_rng.pt

    Args:
        restore_path (str): restore_path should be a file or a directory
            if using model parallel
    """
    self._restore_path = restore_path

    if os.path.isfile(restore_path):
        logging.info(
            f'restore_path: {restore_path} is a file. Assuming no megatron model parallelism'
        )
        state_dict = torch.load(restore_path, map_location='cpu')
        if 'checkpoint_version' in state_dict:
            if state_dict['checkpoint_version'] is not None:
                set_checkpoint_version(state_dict['checkpoint_version'])
        else:
            logging.warning(
                'Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.'
            )
            set_checkpoint_version(0)
        # to load from Megatron pretrained checkpoint
        if 'model' in state_dict:
            self.language_model.load_state_dict(
                state_dict['model'][self._language_model_key])
        else:
            self.load_state_dict(state_dict)
        logging.info(f"weights restored from {restore_path}")
    elif os.path.isdir(restore_path):
        # TODO: need to refactor this so we're not repeating code
        # need model parallel groups to restore model parallel checkpoints
        if model_parallel_is_initialized():
            model_parallel_rank = torch.distributed.get_rank(
                group=get_model_parallel_group())
            mp_restore_path = f'{restore_path}/mp_rank_{model_parallel_rank:02d}/model_optim_rng.pt'
            logging.info(
                f'Restoring model parallel checkpoint from: {mp_restore_path}')
            state_dict = torch.load(mp_restore_path, map_location='cpu')
            if 'checkpoint_version' in state_dict:
                if state_dict['checkpoint_version'] is not None:
                    set_checkpoint_version(state_dict['checkpoint_version'])
            else:
                logging.warning(
                    'Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.'
                )
                set_checkpoint_version(0)
            # to load from Megatron pretrained checkpoint
            if 'model' in state_dict:
                self.language_model.load_state_dict(
                    state_dict['model'][self._language_model_key])
            else:
                self.load_state_dict(state_dict)
        else:
            logging.info(
                'torch.distributed not initialized yet. Will not restore model parallel checkpoint'
            )
    else:
        logging.error(
            f'restore_path: {restore_path} must be a file or directory.')