def save_ds_checkpoint(iteration, model, neox_args):
    """Save a model checkpoint."""
    sd = {
        'iteration': iteration,
        'args': {
            'num_layers': neox_args.num_layers,
            'hidden_size': neox_args.hidden_size,
            'num_attention_heads': neox_args.num_attention_heads,
            'max_position_embeddings': neox_args.max_position_embeddings,
            'make_vocab_size_divisible_by': neox_args.make_vocab_size_divisible_by,
            'padded_vocab_size': neox_args.padded_vocab_size,
            'tokenizer_type': neox_args.tokenizer_type,
            'model_parallel_size': neox_args.model_parallel_size
        }
    }

    # rng states.
    if not neox_args.no_save_rng:
        sd['random_rng_state'] = random.getstate()
        sd['np_rng_state'] = np.random.get_state()
        sd['torch_rng_state'] = torch.get_rng_state()
        sd['cuda_rng_state'] = torch.cuda.get_rng_state()
        sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()

    if neox_args.checkpoint_validation_with_forward_pass:
        logits = do_forward_pass(neox_args=neox_args, model=model)
        sd['checkpoint_validation_logits'] = logits

    model.save_checkpoint(neox_args.save, client_state=sd)
def save_ds_checkpoint(iteration, model, args):
    """Save a model checkpoint."""
    sd = {}
    sd['iteration'] = iteration
    sd['tokens'] = args.tokens
    sd['checkpoint_version'] = 2.0
    sd['args'] = args

    # rng states.
    if not args.no_save_rng:
        sd['random_rng_state'] = random.getstate()
        sd['np_rng_state'] = np.random.get_state()
        sd['torch_rng_state'] = torch.get_rng_state()
        sd['cuda_rng_state'] = torch.cuda.get_rng_state()
        sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()

    # The Megatron model uses state_dict_for_save_checkpoint instead of the
    # standard state_dict. DeepSpeed calls state_dict when saving the module,
    # so it needs to point to the right function while saving.
    original_state_dict = model.module.state_dict
    model.module.state_dict = model.module.state_dict_for_save_checkpoint
    try:
        model.save_checkpoint(args.save, client_state=sd)
    finally:
        model.module.state_dict = original_state_dict
def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint."""
    args = get_args()

    if args.deepspeed:
        save_ds_checkpoint(iteration, model, args)
    else:
        # Only rank zero of the data parallel writes to the disk.
        if isinstance(model, torchDDP):
            model = model.module
        if mpu.get_data_parallel_rank() == 0:

            # Arguments, iteration, and model.
            state_dict = {}
            state_dict['args'] = args
            state_dict['checkpoint_version'] = 2.0
            state_dict['iteration'] = iteration
            state_dict['model'] = model.state_dict_for_save_checkpoint()

            # Optimizer stuff.
            if not args.no_save_optim:
                if optimizer is not None:
                    state_dict['optimizer'] = optimizer.state_dict()
                if lr_scheduler is not None:
                    state_dict['lr_scheduler'] = lr_scheduler.state_dict()

            # RNG states.
            if not args.no_save_rng:
                state_dict['random_rng_state'] = random.getstate()
                state_dict['np_rng_state'] = np.random.get_state()
                state_dict['torch_rng_state'] = torch.get_rng_state()
                state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
                state_dict['rng_tracker_states'] \
                    = mpu.get_cuda_rng_tracker().get_states()

            # Save.
            checkpoint_name = get_checkpoint_name(args.save, iteration)
            print('global rank {} is saving checkpoint at iteration {:7d} to {}'.
                  format(torch.distributed.get_rank(), iteration,
                         checkpoint_name))
            ensure_directory_exists(checkpoint_name)
            torch.save(state_dict, checkpoint_name)
            print(' successfully saved {}'.format(checkpoint_name))

    # Wait so everyone is done (necessary)
    torch.distributed.barrier()
    # And update the latest iteration
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(str(iteration))
    # Wait so everyone is done (necessary)
    torch.distributed.barrier()
    if args.keep_last_n_checkpoints is not None:
        delete_old_checkpoints(args.save, args.keep_last_n_checkpoints)
    # Wait so everyone is done (not necessary)
    torch.distributed.barrier()
def _get_attention_probs(self, attention_scores):
    """Attention probabilities with dropout. The output has the size
    [b, np, s, s].
    """
    # Attention probabilities. [b, np, s, s]
    if self.apply_query_key_layer_scaling:
        attention_scores = attention_scores * self.layer_number
    attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    with mpu.get_cuda_rng_tracker().fork():
        attention_probs = self.attention_dropout(attention_probs)

    return attention_probs
def save_ds_checkpoint(iteration, model, neox_args):
    """Save a model checkpoint."""
    sd = {
        "iteration": iteration,
        "args": {
            "num_layers": neox_args.num_layers,
            "hidden_size": neox_args.hidden_size,
            "num_attention_heads": neox_args.num_attention_heads,
            "max_position_embeddings": neox_args.max_position_embeddings,
            "make_vocab_size_divisible_by": neox_args.make_vocab_size_divisible_by,
            "padded_vocab_size": neox_args.padded_vocab_size,
            "tokenizer_type": neox_args.tokenizer_type,
            "model_parallel_size": neox_args.model_parallel_size,
        },
    }

    # rng states.
    if not neox_args.no_save_rng:
        sd["random_rng_state"] = random.getstate()
        sd["np_rng_state"] = np.random.get_state()
        sd["torch_rng_state"] = torch.get_rng_state()
        sd["cuda_rng_state"] = torch.cuda.get_rng_state()
        sd["rng_tracker_states"] = mpu.get_cuda_rng_tracker().get_states()

    if neox_args.checkpoint_validation_with_forward_pass:
        logits = do_forward_pass(neox_args=neox_args, model=model)
        sd["checkpoint_validation_logits"] = logits

    # checkpoint folder name
    tag = f"global_step{iteration}"

    # save checkpoint
    model.save_checkpoint(neox_args.save, tag=tag, client_state=sd)

    # save config files
    if torch.distributed.get_rank() == 0 and neox_args.config_files is not None:
        configs_directory = os.path.join(neox_args.save, tag, "configs")
        os.makedirs(configs_directory, exist_ok=True)
        for config_filename, config_data in neox_args.config_files.items():
            with open(os.path.join(configs_directory, config_filename), "w") as f:
                f.write(config_data)
def save_ds_checkpoint(iteration, model, args):
    """Save a model checkpoint."""
    sd = {}
    sd['iteration'] = iteration

    # rng states.
    if not args.no_save_rng:
        sd['random_rng_state'] = random.getstate()
        sd['np_rng_state'] = np.random.get_state()
        sd['torch_rng_state'] = torch.get_rng_state()
        sd['cuda_rng_state'] = torch.cuda.get_rng_state()
        sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()

    if args.pipe_parallel_size == 0:
        # The Megatron model uses state_dict_for_save_checkpoint instead of
        # the standard state_dict. DeepSpeed calls state_dict when saving the
        # module, so it needs to point to the right function.
        model.module.state_dict = model.module.state_dict_for_save_checkpoint
    else:
        # Pipeline parallelism manages its own state_dict.
        pass

    model.save_checkpoint(args.save, client_state=sd)
def get_rng_state():
    """Collect rng state across data parallel ranks."""
    args = get_args()
    rng_state = {
        'random_rng_state': random.getstate(),
        'np_rng_state': np.random.get_state(),
        'torch_rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state(),
        'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states()
    }

    rng_state_list = None
    if torch.distributed.is_initialized() and \
            mpu.get_data_parallel_world_size() > 1 and \
            args.data_parallel_random_init:
        rng_state_list = \
            [None for i in range(mpu.get_data_parallel_world_size())]
        torch.distributed.all_gather_object(
            rng_state_list,
            rng_state,
            group=mpu.get_data_parallel_group())
    else:
        rng_state_list = [rng_state]

    return rng_state_list
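# A hedged sketch (not from the source) of the inverse operation: restoring the
# per-rank rng state gathered by get_rng_state above. Indexing the list by the
# data-parallel rank is an assumption made for illustration; when
# data_parallel_random_init is off, the list holds a single shared entry.
def set_rng_state(rng_state_list):
    """Restore this rank's rng state from a list produced by get_rng_state()."""
    if len(rng_state_list) > 1:
        # One entry per data-parallel rank (assumed ordering by rank).
        rng_state = rng_state_list[mpu.get_data_parallel_rank()]
    else:
        # Single shared entry: every rank restores the same state.
        rng_state = rng_state_list[0]
    random.setstate(rng_state['random_rng_state'])
    np.random.set_state(rng_state['np_rng_state'])
    torch.set_rng_state(rng_state['torch_rng_state'])
    torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
    mpu.get_cuda_rng_tracker().set_states(rng_state['rng_tracker_states'])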
def load_checkpoint(model, optimizer, lr_scheduler):
    """Load a model checkpoint and return the iteration."""
    args = get_args()

    if isinstance(model, torchDDP):
        model = model.module
    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(args.load)

    # If no tracker file, return iteration zero.
    if not os.path.isfile(tracker_filename):
        print_rank_0('WARNING: could not find the metadata file {} '.format(
            tracker_filename))
        print_rank_0(' will not load any checkpoints and will start from '
                     'random')
        return 0

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration = 0
    release = False
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == 'release'
            if not release:
                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                    tracker_filename))
                sys.exit()

    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
        tracker_filename)

    # Checkpoint.
    checkpoint_name = get_checkpoint_name(args.load, iteration, release)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    # Load the checkpoint.
    try:
        state_dict = torch.load(checkpoint_name, map_location='cpu')
    except ModuleNotFoundError:
        # For backward compatibility.
        print_rank_0(' > deserializing using the old code structure ...')
        sys.modules['fp16.loss_scaler'] = sys.modules[
            'megatron.fp16.loss_scaler']
        state_dict = torch.load(checkpoint_name, map_location='cpu')
        sys.modules.pop('fp16.loss_scaler', None)
    except BaseException:
        print_rank_0('could not load the checkpoint')
        sys.exit()

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = state_dict['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = state_dict['total_iters']
            except KeyError:
                print_rank_0('A metadata file exists but unable to load '
                             'iteration from checkpoint {}, exiting'.format(
                                 checkpoint_name))
                sys.exit()

    # Check arguments.
    if 'args' in state_dict:
        checkpoint_args = state_dict['args']
        check_checkpoint_args(checkpoint_args)
    else:
        print_rank_0('could not find arguments in the checkpoint ...')

    # Model.
    model.load_state_dict(state_dict['model'])

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(state_dict['optimizer'])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
        except KeyError:
            print_rank_0('Unable to load optimizer from checkpoint {}. '
                         'Specify --no-load-optim or --finetune to prevent '
                         'attempting to load the optimizer state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(state_dict['random_rng_state'])
            np.random.set_state(state_dict['np_rng_state'])
            torch.set_rng_state(state_dict['torch_rng_state'])
            torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(
                state_dict['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load rng state from checkpoint {}. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the rng state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit()

    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return iteration
def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
    """Load a model checkpoint and return the iteration.

    strict (bool): whether to strictly enforce that the keys in
        :attr:`state_dict` of the checkpoint match the names of
        parameters and buffers in model.
    """
    args = get_args()
    load_dir = getattr(args, load_arg)

    model = utils.unwrap_model(model)

    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

    # If no tracker file, return iteration zero.
    if not os.path.isfile(tracker_filename):
        print_rank_0('WARNING: could not find the metadata file {} '.format(
            tracker_filename))
        print_rank_0(' will not load any checkpoints and will start from '
                     'random')
        return 0

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration = 0
    release = False
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == 'release'
            if not release:
                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                    tracker_filename))
                sys.exit()

    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
        tracker_filename)

    # Checkpoint.
    checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
    print_rank_0(
        f' loading checkpoint from {args.load} at iteration {iteration}')

    # Load the checkpoint.
    try:
        state_dict = torch.load(checkpoint_name, map_location='cpu')
    except ModuleNotFoundError:
        from megatron.fp16_deprecated import loss_scaler
        # For backward compatibility.
        print_rank_0(' > deserializing using the old code structure ...')
        sys.modules['fp16.loss_scaler'] = sys.modules[
            'megatron.fp16_deprecated.loss_scaler']
        sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
            'megatron.fp16_deprecated.loss_scaler']
        state_dict = torch.load(checkpoint_name, map_location='cpu')
        sys.modules.pop('fp16.loss_scaler', None)
        sys.modules.pop('megatron.fp16.loss_scaler', None)
    except BaseException:
        print_rank_0('could not load the checkpoint')
        sys.exit()

    # set checkpoint version
    set_checkpoint_version(state_dict.get('checkpoint_version', 0))

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = state_dict['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = state_dict['total_iters']
            except KeyError:
                print_rank_0('A metadata file exists but unable to load '
                             'iteration from checkpoint {}, exiting'.format(
                                 checkpoint_name))
                sys.exit()

    # Check arguments.
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    if 'args' in state_dict:
        checkpoint_args = state_dict['args']
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(checkpoint_args,
                                              'consumed_train_samples', 0)
        update_num_microbatches(consumed_samples=args.consumed_train_samples)
        args.consumed_valid_samples = getattr(checkpoint_args,
                                              'consumed_valid_samples', 0)
    else:
        print_rank_0('could not find arguments in the checkpoint ...')

    # Model.
    if len(model) == 1:
        model[0].load_state_dict(state_dict['model'], strict=strict)
    else:
        for i in range(len(model)):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            model[i].load_state_dict(state_dict['model%d' % i], strict=strict)

    # Fix up query/key/value matrix ordering if needed
    checkpoint_version = get_checkpoint_version()
    print_rank_0(f' checkpoint version {checkpoint_version}')
    fix_query_key_value_ordering(model, checkpoint_version)

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(state_dict['optimizer'])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
        except KeyError:
            print_rank_0('Unable to load optimizer from checkpoint {}. '
                         'Specify --no-load-optim or --finetune to prevent '
                         'attempting to load the optimizer state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(state_dict['random_rng_state'])
            np.random.set_state(state_dict['np_rng_state'])
            torch.set_rng_state(state_dict['torch_rng_state'])
            torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
            # Check for empty states array
            if not state_dict['rng_tracker_states']:
                raise KeyError
            mpu.get_cuda_rng_tracker().set_states(
                state_dict['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load rng state from checkpoint {}. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the rng state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit()

    # Some utilities want to load a checkpoint without distributed being initialized
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0(f' successfully loaded checkpoint from {args.load} '
                 f'at iteration {iteration}')

    return iteration
def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint."""
    args = get_args()

    # Only rank zero of the data parallel writes to the disk.
    model = utils.unwrap_model(model)

    print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))

    if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:

        # Arguments, iteration, and model.
        state_dict = {}
        state_dict['args'] = args
        state_dict['checkpoint_version'] = 3.0
        state_dict['iteration'] = iteration
        if len(model) == 1:
            state_dict['model'] = model[0].state_dict_for_save_checkpoint()
        else:
            for i in range(len(model)):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                state_dict['model%d' % i] = \
                    model[i].state_dict_for_save_checkpoint()

        # Optimizer stuff.
        if not args.no_save_optim:
            if optimizer is not None:
                state_dict['optimizer'] = optimizer.state_dict()
            if lr_scheduler is not None:
                state_dict['lr_scheduler'] = lr_scheduler.state_dict()

        # RNG states.
        if not args.no_save_rng:
            state_dict['random_rng_state'] = random.getstate()
            state_dict['np_rng_state'] = np.random.get_state()
            state_dict['torch_rng_state'] = torch.get_rng_state()
            state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
            state_dict['rng_tracker_states'] \
                = mpu.get_cuda_rng_tracker().get_states()

        # Save.
        checkpoint_name = get_checkpoint_name(args.save, iteration)
        ensure_directory_exists(checkpoint_name)
        torch.save(state_dict, checkpoint_name)

    # Wait so everyone is done (necessary)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))

    # And update the latest iteration
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(str(iteration))

    # Wait so everyone is done (not necessary)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
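# A minimal usage sketch (not from the source): how a training loop might call
# the save_checkpoint above on a fixed interval. `train_step` and the loop
# structure are illustrative assumptions; `args.save`, `args.save_interval`,
# and `args.train_iters` are assumed to be configured by the caller.
def train_with_periodic_checkpoints(model, optimizer, lr_scheduler,
                                    train_step, args, start_iteration=0):
    """Run training iterations and checkpoint every `args.save_interval` steps."""
    iteration = start_iteration
    while iteration < args.train_iters:
        train_step(model, optimizer, lr_scheduler)  # hypothetical step function
        iteration += 1
        # Save on the configured interval and once more at the very end.
        if args.save and (iteration % args.save_interval == 0
                          or iteration == args.train_iters):
            save_checkpoint(iteration, model, optimizer, lr_scheduler)
    return iteration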
def forward(self, hidden_states, attention_mask, rotary_pos_emb=None,
            layer_past=None, get_key_value=False):
    # hidden_states: [sq, b, h]

    # =====================
    # Query, Key, and Value
    # =====================

    # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
    mixed_x_layer, _ = self.query_key_value(hidden_states)

    checkpoint_version = get_checkpoint_version()
    if checkpoint_version is not None:
        if checkpoint_version == 0:
            # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
            mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
        elif checkpoint_version == 1.0:
            # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
            mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)

    # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
    new_tensor_shape = mixed_x_layer.size()[:-1] + \
        (self.num_attention_heads_per_partition,
         3 * self.hidden_size_per_attention_head)
    mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

    # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
    (query_layer, key_layer,
     value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)

    if exists(rotary_pos_emb):
        query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer,
                                                      rotary_pos_emb)

    # ==================================
    # Adjust key and value for inference
    # ==================================

    if layer_past is not None:
        past_key, past_value = layer_past
        key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0)
        value_layer = torch.cat((past_value.type_as(value_layer), value_layer),
                                dim=0)
    if get_key_value:
        present = (key_layer, value_layer)

    if not self.sparse:
        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1), query_layer.size(2),
                       query_layer.size(0), key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device())

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),  # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ==================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ==================================================

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]

        # ===========================
        # Attention probs and dropout
        # ===========================

        if exists(self.rpe):
            rpe = self.rpe(query_layer.size(0), key_layer.size(0))
            attention_scores += rpe  # [1, np, sq, sk]

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1), value_layer.size(2),
                       query_layer.size(0), value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)
    else:
        # shape of q/k/v is [sq, b, np, hn] and needs to be transposed to
        # [b, np, sq, hn]
        query_layer, key_layer, value_layer = map(
            lambda t: t.permute(1, 2, 0, 3).contiguous(),
            (query_layer, key_layer, value_layer))

        # output shape [b, np(heads), sq, hn]
        attn_mask = attention_mask.to(query_layer.dtype) * -10000
        if exists(self.rpe):
            rpe = self.rpe(query_layer.size(0), key_layer.size(0))
        else:
            rpe = None
        context_layer = self.sparse_attn(query_layer, key_layer, value_layer,
                                         attn_mask=attn_mask, rpe=rpe)

    # [b, np, sq, hn] --> [sq, b, np, hn]
    context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

    # [sq, b, np, hn] --> [sq, b, hp]
    new_context_layer_shape = context_layer.size()[:-2] + \
        (self.hidden_size_per_partition,)
    context_layer = context_layer.view(*new_context_layer_shape)

    # =================
    # Output. [sq, b, h]
    # =================

    output, bias = self.dense(context_layer)

    if get_key_value:
        output = [output, present]

    return output, bias
def attention(self, query_layer, key_layer, value_layer, layer_past,
              attention_mask):
    # ===================================
    # Raw attention scores. [b, np, s, s]
    # ===================================

    # [b, np, sq, sk]
    output_size = (
        query_layer.size(1),
        query_layer.size(2),
        query_layer.size(0),
        key_layer.size(0),
    )

    # [sq, b, np, hn] -> [sq, b * np, hn]
    query_layer = query_layer.view(output_size[2],
                                   output_size[0] * output_size[1], -1)
    key_layer = key_layer.view(output_size[3],
                               output_size[0] * output_size[1], -1)

    # preallocating result tensor: [b * np, sq, sk]
    matmul_result = torch.empty(
        output_size[0] * output_size[1],
        output_size[2],
        output_size[3],
        dtype=query_layer.dtype,
        device=torch.cuda.current_device(),
    )

    # Raw attention scores. [b * np, sq, sk]
    matmul_result = torch.baddbmm(
        matmul_result,
        query_layer.transpose(0, 1),  # [b * np, sq, hn]
        key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
        beta=0.0,
        alpha=(1.0 / self.norm_factor),
    )

    # change view to [b, np, sq, sk]
    attention_scores = matmul_result.view(*output_size)

    # ==================================================
    # Update attention mask for inference. [b, np, sq, sk]
    # ==================================================

    if self.use_cache:
        with torch.no_grad():
            attention_mask = attention_mask[...,
                                            :attention_scores.size(3),
                                            :attention_scores.size(3)]

    # ===========================
    # Attention probs and dropout
    # ===========================

    if exists(self.rpe):
        rpe = self.rpe(query_layer.size(0), key_layer.size(0))
        attention_scores += rpe  # [1, np, sq, sk]

    if self.pos_emb == "alibi":
        attention_scores = self.alibi_embed(attention_scores)

    # attention scores and attention mask [b, np, sq, sk]
    attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    with mpu.get_cuda_rng_tracker().fork():
        attention_probs = self.attention_dropout(attention_probs)

    # =========================
    # Context layer. [sq, b, hp]
    # =========================

    # value_layer -> context layer.
    # [sk, b, np, hn] --> [b, np, sq, hn]

    # context layer shape: [b, np, sq, hn]
    output_size = (
        value_layer.size(1),
        value_layer.size(2),
        query_layer.size(0),
        value_layer.size(3),
    )

    # change view [sk, b * np, hn]
    value_layer = value_layer.view(value_layer.size(0),
                                   output_size[0] * output_size[1], -1)

    # change view [b * np, sq, sk]
    attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                           output_size[2], -1)

    # matmul: [b * np, sq, hn]
    context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

    # change view [b, np, sq, hn]
    context_layer = context_layer.view(*output_size)

    return context_layer
def forward(self, hidden_states, attention_mask, layer_past=None,
            get_key_value=False):
    # hidden_states: [s, b, h]

    # =====================
    # Query, Key, and Value
    # =====================

    # Attention heads [s, b, hp] --> [s, b, 3 * hp]
    mixed_x_layer, _ = self.query_key_value(hidden_states)

    # [s, b, 3 * hp] --> [s, b, np, 3 * hn]
    new_tensor_shape = mixed_x_layer.size()[:-1] + \
        (self.num_attention_heads_per_partition,
         3 * self.hidden_size_per_attention_head)
    mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

    # [s, b, np, 3 * hn] --> 3 [s, b, np, hn]
    (query_layer, key_layer,
     value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)

    # ==================================
    # Adjust key and value for inference
    # ==================================

    if layer_past is not None:
        past_key, past_value = layer_past
        key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0)
        value_layer = torch.cat(
            (past_value.type_as(value_layer), value_layer), dim=0)
    if get_key_value:
        present = (key_layer, value_layer)

    # ===================================
    # Raw attention scores. [b, np, s, s]
    # ===================================

    # [b, np, s, s]
    output_size = (query_layer.size(1), query_layer.size(2),
                   query_layer.size(0), key_layer.size(0))

    # [s, b, np, hn] -> [s, b * np, hn]
    query_layer = query_layer.view(output_size[2],
                                   output_size[0] * output_size[1], -1)
    key_layer = key_layer.view(output_size[3],
                               output_size[0] * output_size[1], -1)

    # preallocating result tensor: [b * np, s, s]
    matmul_result = torch.empty(output_size[0] * output_size[1],
                                output_size[2],
                                output_size[3],
                                dtype=query_layer.dtype,
                                device=torch.cuda.current_device())

    # Raw attention scores. [b * np, s, s]
    matmul_result = torch.baddbmm(
        matmul_result,
        query_layer.transpose(0, 1),  # [b * np, s, hn]
        key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, s]
        beta=0.0,
        alpha=(1.0 / self.norm_factor))

    # change view to [b, np, s, s]
    attention_scores = matmul_result.view(*output_size)

    # ==================================================
    # Update attention mask for inference. [b, np, s, s]
    # ==================================================

    if get_key_value:
        with torch.no_grad():
            if layer_past is not None:
                attention_mask = attention_mask[
                    ...,
                    attention_scores.size(3) - 1,
                    :attention_scores.size(3)].unsqueeze(2)
            else:
                attention_mask = attention_mask[
                    ...,
                    :attention_scores.size(3),
                    :attention_scores.size(3)]

    # ===========================
    # Attention probs and dropout
    # ===========================

    # attention scores and attention mask [b, np, s, s]
    attention_probs = self.scale_mask_softmax(attention_scores,
                                              attention_mask)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    with mpu.get_cuda_rng_tracker().fork():
        attention_probs = self.attention_dropout(attention_probs)

    # =========================
    # Context layer. [s, b, hp]
    # =========================

    # value_layer -> context layer.
    # [s, b, np, hn] --> [b, np, s, hn]

    # context layer shape: [b, np, s, hn]
    output_size = (value_layer.size(1), value_layer.size(2),
                   value_layer.size(0), value_layer.size(3))

    # change view [s, b * np, hn]
    value_layer = value_layer.view(output_size[2],
                                   output_size[0] * output_size[1], -1)

    # change view [b * np, s, s]
    attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                           output_size[2], -1)

    # matmul: [b * np, s, hn]
    context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

    # change view [b, np, s, hn]
    context_layer = context_layer.view(*output_size)

    # [b, np, s, hn] --> [s, b, np, hn]
    context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

    # [s, b, np, hn] --> [s, b, hp]
    new_context_layer_shape = context_layer.size()[:-2] + \
        (self.hidden_size_per_partition,)
    context_layer = context_layer.view(*new_context_layer_shape)

    # =================
    # Output. [s, b, h]
    # =================

    output, bias = self.dense(context_layer)

    if get_key_value:
        output = [output, present]

    return output, bias
def load_checkpoint(neox_args, model, optimizer, lr_scheduler, inference=False):
    """Load a model checkpoint and return the iteration."""
    if neox_args.deepspeed:
        load_optim_and_scheduler = not neox_args.no_load_optim  # TODO: These should be configured by separate args
        checkpoint_name, state_dict = model.load_checkpoint(
            neox_args.load,
            load_optimizer_states=load_optim_and_scheduler,
            load_lr_scheduler_states=load_optim_and_scheduler)

        if checkpoint_name is None:
            if mpu.get_data_parallel_rank() == 0:
                print("Unable to load checkpoint.")
            return 0  # iteration 0, if no checkpoint loaded
    else:
        raise ValueError('Must be using deepspeed to use neox')

    # Set iteration.
    if neox_args.finetune:
        iteration = 0
    else:
        # total_iters is backward compatible with older checkpoints
        iteration = state_dict.get('iteration') or state_dict.get('total_iters')
        if iteration is None:
            raise ValueError('Unable to load iteration from checkpoint {}, '
                             'exiting'.format(checkpoint_name))

    # Check arguments.
    if 'args' in state_dict:
        checkpoint_args = state_dict['args']
        check_checkpoint_args(neox_args=neox_args,
                              checkpoint_args=checkpoint_args)
        print_rank_0(' > validated currently set args with arguments in the '
                     'checkpoint ...')
    else:
        print_rank_0(' > could not find arguments in the checkpoint for '
                     'validation...')

    # Check loaded checkpoint with forward pass
    if neox_args.checkpoint_validation_with_forward_pass:
        if "checkpoint_validation_logits" in state_dict:
            check_forward_pass(
                neox_args=neox_args,
                model=model,
                checkpoint_logits=state_dict["checkpoint_validation_logits"],
                inference=inference)
            print_rank_0(' > validated loaded checkpoint with forward pass ...')
        else:
            if mpu.get_data_parallel_rank() == 0:
                print(' > WARNING: checkpoint_validation_with_forward_pass is '
                      'configured but no checkpoint validation data available '
                      'in checkpoint {}'.format(checkpoint_name))

    # rng states.
    if not neox_args.finetune and not neox_args.no_load_rng:
        try:
            random.setstate(state_dict['random_rng_state'])
            np.random.set_state(state_dict['np_rng_state'])
            torch.set_rng_state(state_dict['torch_rng_state'])
            torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(
                state_dict['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load rng state from checkpoint {}. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the rng state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit()

    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return iteration
def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"): """Load a model checkpoint and return the iteration.""" from megatron import get_args from megatron import mpu from megatron import print_rank_last from megatron.checkpointing import get_checkpoint_tracker_filename from megatron.checkpointing import set_checkpoint_version from megatron.checkpointing import check_checkpoint_args from megatron.checkpointing import update_num_microbatches if mpu.get_data_parallel_rank() == 0: # at dp rank 0, we still follow the native load_checkpoint by megatron from megatron.checkpointing import load_checkpoint as load_checkpoint_native return load_checkpoint_native(model, optimizer, lr_scheduler, load_arg) args = get_args() load_dir = getattr(args, load_arg) if isinstance(model, DistributedDataParallel): model = model.module # Read the tracker file and set the iteration. tracker_filename = get_checkpoint_tracker_filename(load_dir) # If no tracker file, return iretation zero. if not os.path.isfile(tracker_filename): print_rank_last("WARNING: could not find the metadata file {} ".format( tracker_filename)) print_rank_last( " will not load any checkpoints and will start from " "random") return 0 # Otherwise, read the tracker file and either set the iteration or # mark it as a release checkpoint. iteration = 0 release = False with open(tracker_filename, "r") as f: metastring = f.read().strip() try: iteration = int(metastring) except ValueError: release = metastring == "release" if not release: print_rank_last( "ERROR: Invalid metadata file {}. Exiting".format( tracker_filename)) sys.exit() assert iteration > 0 or release, "error parsing metadata file {}".format( tracker_filename) # Checkpoint. checkpoint_name_rank0 = get_fmoe_checkpoint_name(load_dir, iteration, release, 0) checkpoint_name_local = get_fmoe_checkpoint_name( load_dir, iteration, release, mpu.get_data_parallel_rank()) print_rank_last( " loading checkpoint at rank 0 from {} and rank {} from {} at iteration {}, will merge them later" .format( checkpoint_name_rank0, mpu.get_data_parallel_rank(), checkpoint_name_local, iteration, )) # Load the checkpoint. def load_state_dict(checkpoint_name): try: state_dict = torch.load(checkpoint_name, map_location="cpu") except ModuleNotFoundError: from megatron.fp16_deprecated import loss_scaler # For backward compatibility. print_rank_last( " > deserializing using the old code structure ...") sys.modules["fp16.loss_scaler"] = sys.modules[ "megatron.fp16_deprecated.loss_scaler"] sys.modules["megatron.fp16.loss_scaler"] = sys.modules[ "megatron.fp16_deprecated.loss_scaler"] state_dict = torch.load(checkpoint_name, map_location="cpu") sys.modules.pop("fp16.loss_scaler", None) sys.modules.pop("megatron.fp16.loss_scaler", None) except BaseException: print_rank_last("could not load the checkpoint") sys.exit() return state_dict state_dict_rank0 = load_state_dict(checkpoint_name_rank0) state_dict_local = load_state_dict(checkpoint_name_local) state_dict = merge_state_dict(state_dict_rank0, state_dict_local, args.fp16) # set checkpoint version set_checkpoint_version(state_dict.get("checkpoint_version", 0)) # Set iteration. if args.finetune or release: iteration = 0 else: try: iteration = state_dict["iteration"] except KeyError: try: # Backward compatible with older checkpoints iteration = state_dict["total_iters"] except KeyError: print_rank_last("A metadata file exists but unable to load " "iteration from checkpoint {}, exiting".format( checkpoint_name_local)) sys.exit() # Check arguments. 
assert args.consumed_train_samples == 0 assert args.consumed_valid_samples == 0 if "args" in state_dict: checkpoint_args = state_dict["args"] check_checkpoint_args(checkpoint_args) args.consumed_train_samples = getattr(checkpoint_args, "consumed_train_samples", 0) update_num_microbatches(consumed_samples=args.consumed_train_samples) args.consumed_valid_samples = getattr(checkpoint_args, "consumed_valid_samples", 0) else: print_rank_last("could not find arguments in the checkpoint ...") # Model. model.load_state_dict(state_dict["model"]) # Optimizer. if not release and not args.finetune and not args.no_load_optim: try: if optimizer is not None: optimizer.load_state_dict(state_dict["optimizer"]) if lr_scheduler is not None: lr_scheduler.load_state_dict(state_dict["lr_scheduler"]) except KeyError: print_rank_last("Unable to load optimizer from checkpoint {}. " "Specify --no-load-optim or --finetune to prevent " "attempting to load the optimizer state, " "exiting ...".format(checkpoint_name_local)) sys.exit() # rng states. if not release and not args.finetune and not args.no_load_rng: try: random.setstate(state_dict["random_rng_state"]) np.random.set_state(state_dict["np_rng_state"]) torch.set_rng_state(state_dict["torch_rng_state"]) torch.cuda.set_rng_state(state_dict["cuda_rng_state"]) mpu.get_cuda_rng_tracker().set_states( state_dict["rng_tracker_states"]) except KeyError: print_rank_last("Unable to load optimizer from checkpoint {}. " "Specify --no-load-rng or --finetune to prevent " "attempting to load the optimizer state, " "exiting ...".format(checkpoint_name_local)) sys.exit() torch.distributed.barrier() print_rank_last( " successfully loaded checkpoint (with expert parametes updated) from {} at iteration {}" .format(args.load, iteration)) return iteration
def load_checkpoint(neox_args, model, optimizer, lr_scheduler, inference=False,
                    iteration=None):
    """Load a model checkpoint and return the iteration."""
    if neox_args.deepspeed:
        load_optim_and_scheduler = (
            not neox_args.no_load_optim
        )  # TODO: These should be configured by separate args
        if neox_args.finetune:
            load_optim_and_scheduler = False
        if iteration is not None:
            tag = f"global_step{iteration}"
        else:
            tag = None
        checkpoint_name, state_dict = model.load_checkpoint(
            neox_args.load,
            load_optimizer_states=load_optim_and_scheduler,
            load_lr_scheduler_states=load_optim_and_scheduler,
            tag=tag,
        )

        if checkpoint_name is None:
            # if an iteration is specified, we want to raise an error here rather than
            # continuing silently, since we are trying to load a specific checkpoint
            if iteration is not None:
                available_checkpoints = sorted([
                    int(i.name.replace("global_step", ""))
                    for i in Path(neox_args.load).glob("global_step*")
                ])
                raise ValueError(
                    f"Unable to load checkpoint for iteration {iteration}. \n"
                    f"Available iterations: {pformat(available_checkpoints)}")
            if mpu.get_data_parallel_rank() == 0:
                print("Unable to load checkpoint.")

            return 0  # iteration 0, if no checkpoint loaded
    else:
        raise ValueError("Must be using deepspeed to use neox")

    # Set iteration.
    if neox_args.finetune:
        iteration = 0
    else:
        iteration = state_dict.get("iteration") or state_dict.get(
            "total_iters"
        )  # total_iters backward compatible with older checkpoints
        if iteration is None:
            raise ValueError(
                f"Unable to load iteration from checkpoint {checkpoint_name} "
                f"with keys {state_dict.keys()}, exiting")

    # Check arguments.
    if "args" in state_dict:
        checkpoint_args = state_dict["args"]
        check_checkpoint_args(neox_args=neox_args,
                              checkpoint_args=checkpoint_args)
        print_rank_0(
            " > validated currently set args with arguments in the checkpoint ..."
        )
    else:
        print_rank_0(
            " > could not find arguments in the checkpoint for validation...")

    # Check loaded checkpoint with forward pass
    if neox_args.checkpoint_validation_with_forward_pass:
        if "checkpoint_validation_logits" in state_dict:
            check_forward_pass(
                neox_args=neox_args,
                model=model,
                checkpoint_logits=state_dict["checkpoint_validation_logits"],
                inference=inference,
            )
            print_rank_0(" > validated loaded checkpoint with forward pass ...")
        else:
            if mpu.get_data_parallel_rank() == 0:
                print(
                    " > WARNING: checkpoint_validation_with_forward_pass is "
                    "configured but no checkpoint validation data available "
                    "in checkpoint {}".format(checkpoint_name))

    # rng states.
    if not neox_args.finetune and not neox_args.no_load_rng:
        try:
            random.setstate(state_dict["random_rng_state"])
            np.random.set_state(state_dict["np_rng_state"])
            torch.set_rng_state(state_dict["torch_rng_state"])
            torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
            mpu.get_cuda_rng_tracker().set_states(
                state_dict["rng_tracker_states"])
        except KeyError:
            print_rank_0("Unable to load rng state from checkpoint {}. "
                         "Specify --no-load-rng or --finetune to prevent "
                         "attempting to load the rng state, "
                         "exiting ...".format(checkpoint_name))
            sys.exit()

    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print(" successfully loaded {}".format(checkpoint_name))

    return iteration
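# A minimal resume sketch (not from the source), assuming the NeoX-style
# load_checkpoint above: the returned iteration seeds the training loop so a
# resumed run continues from where the checkpoint left off.
# `neox_args.train_iters` and `train_step` are illustrative assumptions.
def resume_training(neox_args, model, optimizer, lr_scheduler, train_step):
    """Load the latest checkpoint (if any) and continue training from it."""
    iteration = load_checkpoint(neox_args, model, optimizer, lr_scheduler)
    while iteration < neox_args.train_iters:
        train_step(model, optimizer, lr_scheduler)  # hypothetical step function
        iteration += 1
    return iteration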