def _load_checkpoint(self, filename):
    """Helper function for loading megatron checkpoints.

    Args:
        filename (str): Path to megatron checkpoint.
    """
    state_dict = torch.load(filename, map_location='cpu')
    if 'checkpoint_version' in state_dict:
        if state_dict['checkpoint_version'] is not None:
            set_checkpoint_version(state_dict['checkpoint_version'])
            logging.info(
                f"Megatron-lm checkpoint version found. "
                f"Setting checkpoint_version to {state_dict['checkpoint_version']}."
            )
    else:
        logging.warning('Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.')
        set_checkpoint_version(0)

    # to load from Megatron pretrained checkpoint
    if 'model' in state_dict:
        self.language_model.load_state_dict(state_dict['model'][self._language_model_key])
    else:
        self.load_state_dict(state_dict)

    logging.info(f"Checkpoint loaded from {filename}")
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    """LightningModule hook that's used to restore things saved with on_save_checkpoint."""
    if hasattr(self, "bert_model") and isinstance(self.bert_model, MegatronBertEncoder):
        if get_checkpoint_version():
            assert (
                checkpoint['checkpoint_version'] == get_checkpoint_version()
            ), 'checkpoint version found in on_load_checkpoint differs from get_checkpoint_version()'
        else:
            set_checkpoint_version(checkpoint['checkpoint_version'])
            logging.info(f"Setting Megatron checkpoint version: {checkpoint['checkpoint_version']}")
    return None
def start_training(self, trainer: 'Trainer') -> None:
    """PTL hook that is called after DDP is initialized."""
    if isinstance(self.lightning_module.bert_model, MegatronBertEncoder):
        app_state = AppState()
        if app_state.model_parallel_size is not None:
            # mpu grad clipping needs parameters to have the attribute model_parallel
            parameters = self.lightning_module.parameters()
            for p in parameters:
                if not hasattr(p, 'model_parallel'):
                    p.model_parallel = False

            # TODO: figure out how to override clip gradients again
            # Update PTL trainer to use our _clip_gradients
            # self._trainer.accelerator_backend._clip_gradients = self._clip_gradients

            if get_checkpoint_version():
                # Restored from .nemo, checkpoint_version will already be set
                pass
            elif trainer.resume_from_checkpoint is not None:
                # PTL auto-resuming, need to update checkpoint name
                # update path based on model parallel rank
                filepath = trainer.resume_from_checkpoint
                dirname = os.path.dirname(os.path.dirname(filepath))
                basename = os.path.basename(filepath)
                filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'
                trainer.resume_from_checkpoint = filepath
                logging.info(f'Resuming training from checkpoint {trainer.resume_from_checkpoint}')
                # need to set checkpoint version for megatron-lm
                checkpoint_version = torch.load(trainer.resume_from_checkpoint).get('checkpoint_version', None)
                if checkpoint_version is not None:
                    set_checkpoint_version(checkpoint_version)
                else:
                    logging.warning('Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.')
                    set_checkpoint_version(0)
            else:
                logging.info(
                    f"Restoring from pretrained model parallel checkpoint: "
                    f"{self.lightning_module.bert_model._restore_path}"
                )
                self.lightning_module.bert_model.restore_weights(self.lightning_module.bert_model._restore_path)

        self.lightning_module.register_megatron_checkpoint_version()

    return super().start_training(trainer)
def start_training(self, trainer: 'Trainer') -> None:
    """PTL hook that is called after DDP is initialized."""
    if self.lightning_module.has_megatron_encoder:
        app_state = AppState()
        if app_state.model_parallel_size is not None:
            # mpu grad clipping needs parameters to have the attribute model_parallel
            parameters = self.lightning_module.parameters()
            for p in parameters:
                if not hasattr(p, 'model_parallel'):
                    p.model_parallel = False

            if get_checkpoint_version() is not None:
                # megatron checkpoint already restored
                pass
            elif trainer.resume_from_checkpoint is not None:
                # PTL auto-resuming, need to update checkpoint name
                # update path based on model parallel rank
                filepath = trainer.resume_from_checkpoint
                dirname = os.path.dirname(os.path.dirname(filepath))
                basename = os.path.basename(filepath)
                filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'
                trainer.resume_from_checkpoint = filepath
                logging.info(f'Resuming training from checkpoint {trainer.resume_from_checkpoint}')
                # need to set checkpoint version for megatron-lm
                checkpoint_version = torch.load(trainer.resume_from_checkpoint).get('checkpoint_version', None)
                if checkpoint_version is not None:
                    set_checkpoint_version(checkpoint_version)
                else:
                    logging.warning('Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.')
                    set_checkpoint_version(0)
            else:
                self.lightning_module.restore_megatron_encoder_weights()
        else:
            if get_checkpoint_version() is not None:
                # megatron checkpoint already restored
                pass
            else:
                self.lightning_module.restore_megatron_encoder_weights()

        self.lightning_module.register_megatron_checkpoint_version()

    return super().start_training(trainer)
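# ---------------------------------------------------------------------------
# Illustration (not from the source): how the mp_rank path rewrite used in the
# two start_training variants above maps a flat resume checkpoint path into the
# per-model-parallel-rank directory. The paths and rank value are hypothetical.
import os

filepath = '/results/checkpoints/megatron--last.ckpt'    # hypothetical resume path
model_parallel_rank = 1                                   # hypothetical mp rank
dirname = os.path.dirname(os.path.dirname(filepath))      # '/results'
basename = os.path.basename(filepath)                     # 'megatron--last.ckpt'
rewritten = f'{dirname}/mp_rank_{model_parallel_rank:02d}/{basename}'
print(rewritten)                                          # '/results/mp_rank_01/megatron--last.ckpt'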
def setup(self, stage: str) -> None:
    """PTL hook that is called after DDP is initialized.
    Called at the beginning of fit and test.

    Args:
        stage (str): either 'fit' or 'test'
    """
    # TODO: implement model parallel for test stage
    if stage == 'fit':
        # set find_unused_parameters to True by default for NLP models
        if isinstance(self.trainer.accelerator.training_type_plugin, DDPPlugin):
            self.trainer.accelerator.training_type_plugin._ddp_kwargs['find_unused_parameters'] = True

        # adds self.bert_model config to .nemo file
        if hasattr(self, 'bert_model') and self.bert_model is not None:
            self.register_bert_model()

        app_state = AppState()

        if app_state.model_parallel_size is not None:

            if app_state.model_parallel_group is None:
                self.init_model_parallel(app_state.global_rank, app_state.world_size)

            # mpu grad clipping needs parameters to have the attribute model_parallel
            parameters = self._trainer.get_model().parameters()
            for p in parameters:
                if not hasattr(p, 'model_parallel'):
                    p.model_parallel = False

            # Update PTL trainer to use our configure_ddp
            self._trainer.accelerator_backend.ddp_plugin.configure_ddp = self.configure_ddp

            # Update PTL trainer to use our _clip_gradients
            self._trainer.accelerator_backend._clip_gradients = self._clip_gradients
            self._trainer.checkpoint_connector = NLPCheckpointConnector(self._trainer)

            # Configure checkpointing for model parallel
            if app_state.create_checkpoint_callback:
                # global rank 0 is configured by exp_manager
                if not is_global_rank_zero() and app_state.data_parallel_rank == 0:
                    configure_checkpointing(
                        self._trainer,
                        app_state.log_dir,
                        app_state.checkpoint_name,
                        app_state.checkpoint_callback_params,
                    )

            if isinstance(self.bert_model, MegatronBertEncoder):
                self.bert_model.complete_lazy_init()

                # model parallel checkpoints need to be restored after torch.distributed is initialized
                if self._trainer.resume_from_checkpoint is not None:
                    # update path based on model parallel rank
                    filepath = self._trainer.resume_from_checkpoint
                    dirname = os.path.dirname(os.path.dirname(filepath))
                    basename = os.path.basename(filepath)
                    filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'
                    self._trainer.resume_from_checkpoint = filepath
                    logging.info(f'Resuming training from checkpoint {self._trainer.resume_from_checkpoint}')
                    # need to set checkpoint version for megatron-lm
                    checkpoint_version = torch.load(self._trainer.resume_from_checkpoint).get(
                        'checkpoint_version', None
                    )
                    if checkpoint_version is not None:
                        set_checkpoint_version(checkpoint_version)
                    else:
                        logging.warning(
                            'Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.'
                        )
                        set_checkpoint_version(0)
                else:
                    logging.info(
                        f"Restoring from pretrained model parallel checkpoint: {self.bert_model._restore_path}"
                    )
                    self.bert_model.restore_weights(self.bert_model._restore_path)

                logging.info("Replacing sampler with model parallel sampler")
                mp_sampler = torch.utils.data.distributed.DistributedSampler(
                    self._train_dl.dataset,
                    num_replicas=app_state.data_parallel_size,
                    rank=app_state.data_parallel_rank,
                )
                mp_dl = self._trainer.replace_sampler(self._train_dl, mp_sampler)
                self._train_dl = mp_dl
            else:
                raise NotImplementedError(
                    f'The BERT encoder: {self.bert_model} does not support model parallelism yet.'
                )
def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"): """Load a model checkpoint and return the iteration.""" from megatron import get_args from megatron import mpu from megatron import print_rank_last from megatron.checkpointing import get_checkpoint_tracker_filename from megatron.checkpointing import set_checkpoint_version from megatron.checkpointing import check_checkpoint_args from megatron.checkpointing import update_num_microbatches if mpu.get_data_parallel_rank() == 0: # at dp rank 0, we still follow the native load_checkpoint by megatron from megatron.checkpointing import load_checkpoint as load_checkpoint_native return load_checkpoint_native(model, optimizer, lr_scheduler, load_arg) args = get_args() load_dir = getattr(args, load_arg) if isinstance(model, DistributedDataParallel): model = model.module # Read the tracker file and set the iteration. tracker_filename = get_checkpoint_tracker_filename(load_dir) # If no tracker file, return iretation zero. if not os.path.isfile(tracker_filename): print_rank_last("WARNING: could not find the metadata file {} ".format( tracker_filename)) print_rank_last( " will not load any checkpoints and will start from " "random") return 0 # Otherwise, read the tracker file and either set the iteration or # mark it as a release checkpoint. iteration = 0 release = False with open(tracker_filename, "r") as f: metastring = f.read().strip() try: iteration = int(metastring) except ValueError: release = metastring == "release" if not release: print_rank_last( "ERROR: Invalid metadata file {}. Exiting".format( tracker_filename)) sys.exit() assert iteration > 0 or release, "error parsing metadata file {}".format( tracker_filename) # Checkpoint. checkpoint_name_rank0 = get_fmoe_checkpoint_name(load_dir, iteration, release, 0) checkpoint_name_local = get_fmoe_checkpoint_name( load_dir, iteration, release, mpu.get_data_parallel_rank()) print_rank_last( " loading checkpoint at rank 0 from {} and rank {} from {} at iteration {}, will merge them later" .format( checkpoint_name_rank0, mpu.get_data_parallel_rank(), checkpoint_name_local, iteration, )) # Load the checkpoint. def load_state_dict(checkpoint_name): try: state_dict = torch.load(checkpoint_name, map_location="cpu") except ModuleNotFoundError: from megatron.fp16_deprecated import loss_scaler # For backward compatibility. print_rank_last( " > deserializing using the old code structure ...") sys.modules["fp16.loss_scaler"] = sys.modules[ "megatron.fp16_deprecated.loss_scaler"] sys.modules["megatron.fp16.loss_scaler"] = sys.modules[ "megatron.fp16_deprecated.loss_scaler"] state_dict = torch.load(checkpoint_name, map_location="cpu") sys.modules.pop("fp16.loss_scaler", None) sys.modules.pop("megatron.fp16.loss_scaler", None) except BaseException: print_rank_last("could not load the checkpoint") sys.exit() return state_dict state_dict_rank0 = load_state_dict(checkpoint_name_rank0) state_dict_local = load_state_dict(checkpoint_name_local) state_dict = merge_state_dict(state_dict_rank0, state_dict_local, args.fp16) # set checkpoint version set_checkpoint_version(state_dict.get("checkpoint_version", 0)) # Set iteration. if args.finetune or release: iteration = 0 else: try: iteration = state_dict["iteration"] except KeyError: try: # Backward compatible with older checkpoints iteration = state_dict["total_iters"] except KeyError: print_rank_last("A metadata file exists but unable to load " "iteration from checkpoint {}, exiting".format( checkpoint_name_local)) sys.exit() # Check arguments. 
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    if "args" in state_dict:
        checkpoint_args = state_dict["args"]
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(checkpoint_args, "consumed_train_samples", 0)
        update_num_microbatches(consumed_samples=args.consumed_train_samples)
        args.consumed_valid_samples = getattr(checkpoint_args, "consumed_valid_samples", 0)
    else:
        print_rank_last("could not find arguments in the checkpoint ...")

    # Model.
    model.load_state_dict(state_dict["model"])

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(state_dict["optimizer"])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
        except KeyError:
            print_rank_last(
                "Unable to load optimizer from checkpoint {}. "
                "Specify --no-load-optim or --finetune to prevent "
                "attempting to load the optimizer state, "
                "exiting ...".format(checkpoint_name_local)
            )
            sys.exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(state_dict["random_rng_state"])
            np.random.set_state(state_dict["np_rng_state"])
            torch.set_rng_state(state_dict["torch_rng_state"])
            torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
            mpu.get_cuda_rng_tracker().set_states(state_dict["rng_tracker_states"])
        except KeyError:
            print_rank_last(
                "Unable to load rng state from checkpoint {}. "
                "Specify --no-load-rng or --finetune to prevent "
                "attempting to load the rng state, "
                "exiting ...".format(checkpoint_name_local)
            )
            sys.exit()

    torch.distributed.barrier()
    print_rank_last(
        " successfully loaded checkpoint (with expert parameters updated) from {} at iteration {}".format(
            args.load, iteration
        )
    )

    return iteration
@classmethod
def restore_from(
    cls,
    restore_path: str,
    override_config_path: Optional[Union[OmegaConf, str]] = None,
    map_location: Optional[torch.device] = None,
    strict: bool = True,
    return_config: bool = False,
    trainer: Trainer = None,
    save_restore_connector: SaveRestoreConnector = None,
):
    """
    Restores model instance (weights and configuration) from .nemo file.

    Args:
        restore_path: path to .nemo file from which model should be instantiated
        override_config_path: path to a yaml config that will override the internal
            config file or an OmegaConf / DictConfig object representing the model config.
        map_location: Optional torch.device() to map the instantiated model to a device.
            By default (None), it will select a GPU if available, falling back to CPU otherwise.
        strict: Passed to load_state_dict. Set to True by default.
        return_config: If set to true, will return just the underlying config of the restored
            model as an OmegaConf DictConfig object without instantiating the model.
        trainer: PyTorch Lightning trainer. Must be passed in order to use model parallel .nemo files.
        save_restore_connector: Optional SaveRestoreConnector that customizes how the model
            is saved and restored.

    Example:
        ```
        model = nemo.collections.nlp.models.TokenClassificationModel.restore_from('token_classification.nemo')
        assert isinstance(model, nemo.collections.nlp.models.TokenClassificationModel)
        ```

    Returns:
        An instance of type cls or its underlying config (if return_config is set).
    """
    if save_restore_connector is None:
        save_restore_connector = SaveRestoreConnector()

    if not os.path.exists(restore_path):
        raise FileNotFoundError(f"Can't find {restore_path}")

    app_state = AppState()
    app_state.model_restore_path = os.path.abspath(os.path.expanduser(restore_path))

    # detect if we have a model parallel .nemo file
    with tempfile.TemporaryDirectory() as tmpdir:
        cwd = os.getcwd()
        os.chdir(tmpdir)
        # detect if model parallel from tarfile
        tar = tarfile.open(app_state.model_restore_path, "r:gz")
        names = tar.getnames()
        mp_ranks = []
        for name in names:
            if 'mp_rank' in name:
                mp_ranks.append(name)
        if mp_ranks:
            # directory and file are included in getnames()
            app_state.model_parallel_size = len(mp_ranks) // 2

            # get checkpoint version
            checkpoint_version_member = None
            for member in tar.getmembers():
                if 'megatron_checkpoint_version.json' in member.name:
                    checkpoint_version_member = member
            tar.extract(checkpoint_version_member, tmpdir)
            with open(checkpoint_version_member.name, 'r') as f:
                checkpoint_version = json.load(f).get('checkpoint_version', None)
            logging.info(
                f'Detected model parallel .nemo file: {restore_path}. '
                f'Assuming megatron model parallelism with '
                f'model_parallel_size: {app_state.model_parallel_size} '
                f'and checkpoint version: {checkpoint_version}'
            )
        tar.close()
        os.chdir(cwd)

    if app_state.model_parallel_size is not None:
        if not isinstance(trainer, Trainer):
            raise ValueError("trainer must be a PyTorch Lightning Trainer to restore model parallel .nemo files.")

        if checkpoint_version is None:
            raise ValueError(
                "Restoring from megatron model parallel .nemo but could not find megatron checkpoint version."
            )
        else:
            logging.info(f"Setting megatron checkpoint version: {checkpoint_version}")
            set_checkpoint_version(checkpoint_version)

        app_state.world_size = trainer.num_gpus * trainer.num_nodes

        if trainer.local_rank is not None:
            app_state.local_rank = trainer.local_rank
        else:
            raise ValueError("trainer.local_rank is None. local_rank needed to restore model parallel models.")

        model_parallel_rank = compute_model_parallel_rank(trainer.local_rank, app_state.model_parallel_size)
        app_state.model_parallel_rank = model_parallel_rank

        cls.update_save_restore_connector(save_restore_connector)
        restored_model = cls._save_restore_connector.restore_from(
            cls, app_state.model_restore_path, override_config_path, map_location, strict, return_config
        )
        restored_model.set_trainer(trainer)
        return restored_model
    else:
        return super().restore_from(
            app_state.model_restore_path,
            override_config_path,
            map_location,
            strict,
            return_config,
            save_restore_connector=save_restore_connector,
        )
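# ---------------------------------------------------------------------------
# Usage sketch (not from the source): restoring a model parallel .nemo file.
# As enforced in restore_from above, a PyTorch Lightning Trainer must be passed
# when the archive contains mp_rank_* checkpoints. The file name and GPU count
# below are hypothetical placeholders.
from pytorch_lightning import Trainer
from nemo.collections.nlp.models import TokenClassificationModel

trainer = Trainer(gpus=2, accelerator='ddp')
model = TokenClassificationModel.restore_from('token_classification_mp2.nemo', trainer=trainer)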
def restore_weights(self, restore_path: str):
    """Restores module/model's weights.
    For model parallel checkpoints the directory structure should be
    restore_path/mp_rank_0X/model_optim_rng.pt

    Args:
        restore_path (str): restore_path should be a file or a directory if using model parallel
    """
    self._restore_path = restore_path

    if os.path.isfile(restore_path):
        logging.info(f'restore_path: {restore_path} is a file. Assuming no megatron model parallelism')
        state_dict = torch.load(restore_path, map_location='cpu')

        if 'checkpoint_version' in state_dict:
            if state_dict['checkpoint_version'] is not None:
                set_checkpoint_version(state_dict['checkpoint_version'])
        else:
            logging.warning('Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.')
            set_checkpoint_version(0)

        # to load from Megatron pretrained checkpoint
        if 'model' in state_dict:
            self.language_model.load_state_dict(state_dict['model'][self._language_model_key])
        else:
            self.load_state_dict(state_dict)

        logging.info(f"weights restored from {restore_path}")
    elif os.path.isdir(restore_path):
        # TODO: need to refactor this so we're not repeating code

        # need model parallel groups to restore model parallel checkpoints
        if model_parallel_is_initialized():
            model_parallel_rank = torch.distributed.get_rank(group=get_model_parallel_group())
            mp_restore_path = f'{restore_path}/mp_rank_{model_parallel_rank:02d}/model_optim_rng.pt'
            logging.info(f'Restoring model parallel checkpoint from: {mp_restore_path}')
            state_dict = torch.load(mp_restore_path, map_location='cpu')

            if 'checkpoint_version' in state_dict:
                if state_dict['checkpoint_version'] is not None:
                    set_checkpoint_version(state_dict['checkpoint_version'])
            else:
                logging.warning('Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.')
                set_checkpoint_version(0)

            # to load from Megatron pretrained checkpoint
            if 'model' in state_dict:
                self.language_model.load_state_dict(state_dict['model'][self._language_model_key])
            else:
                self.load_state_dict(state_dict)
        else:
            logging.info('torch.distributed not initialized yet. Will not restore model parallel checkpoint')
    else:
        logging.error(f'restore_path: {restore_path} must be a file or directory.')