def load_from_metrics(cls, weights_path, tags_csv, overwrite_hparams=None):
    overwrite_hparams = overwrite_hparams or {}

    hparams = load_hparams_from_tags_csv(tags_csv)
    hparams.__dict__["logger"] = eval(hparams.__dict__.get("logger", "None"))

    if (str(hparams.sampling_probs) == "nan"
            or str(hparams.sampling_probs) == "None"
            or len(hparams.sampling_probs) == 0):
        hparams.__dict__["sampling_probs"] = None
    if (str(hparams.audio_file) == "nan"
            or str(hparams.audio_file) == "None"
            or len(hparams.audio_file) == 0):
        hparams.__dict__["audio_file"] = None

    hparams.__setattr__("on_gpu", False)
    hparams.__dict__.update(overwrite_hparams)

    # load on CPU only to avoid OOM issues
    # then it's up to the user to put it back on GPUs
    checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)

    # load the state_dict on the model automatically
    model = cls(hparams)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer = model.configure_optimizers()
    optimizer.load_state_dict(checkpoint["optimizer_states"][0])

    # give model a chance to load something
    model.on_load_checkpoint(checkpoint)

    return model
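# A hedged usage sketch of the classmethod above (illustrative only):
# `MyLightningModule` and both paths are placeholders; `overwrite_hparams`
# patches individual values loaded from the tags CSV before the model is built.
model = MyLightningModule.load_from_metrics(
    weights_path='/path/to/checkpoint.ckpt',
    tags_csv='/path/to/meta_tags.csv',
    overwrite_hparams={'batch_size': 1},
)
model.eval()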
def test_loading_meta_tags(tmpdir):
    """test for backward compatibility to meta_tags.csv"""
    tutils.reset_seed()

    hparams = EvalModelTemplate.get_default_hparams()

    # save tags
    logger = tutils.get_default_logger(tmpdir)
    logger.log_hyperparams(Namespace(some_str='a_str', an_int=1, a_float=2.0))
    logger.log_hyperparams(hparams)
    logger.save()

    # load hparams
    path_expt_dir = tutils.get_data_path(logger, path_dir=tmpdir)
    hparams_path = os.path.join(path_expt_dir, TensorBoardLogger.NAME_HPARAMS_FILE)
    hparams = load_hparams_from_yaml(hparams_path)

    # save as legacy meta_tags.csv
    tags_path = os.path.join(path_expt_dir, 'meta_tags.csv')
    save_hparams_to_tags_csv(tags_path, hparams)

    tags = load_hparams_from_tags_csv(tags_path)

    assert hparams == tags
def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
    r"""
    You should use `load_from_checkpoint` instead!
    However, if your .ckpt weights don't have the hyperparameters saved,
    use this method to pass in a .csv file with the hparams you'd like to use.
    These will be converted into an argparse.Namespace and passed into your
    LightningModule for use.

    Args:
        weights_path (str): Path to a PyTorch checkpoint
        tags_csv (str): Path to a .csv with two columns (key, value)
            as in this Example::

                key,value
                drop_prob,0.2
                batch_size,32

        map_location (dict | str | torch.device | function):
            If your checkpoint saved a GPU model and you now load on CPUs
            or a different number of GPUs, use this to map to the new setup
            (example: {'cuda:1':'cuda:0'}).
            The behaviour is the same as in
            `torch.load <https://pytorch.org/docs/stable/torch.html#torch.load>`_.

    Return:
        LightningModule with loaded weights and hyperparameters (if available).

    Example
    -------
    .. code-block:: python

        pretrained_model = MyLightningModule.load_from_metrics(
            weights_path='/path/to/pytorch_checkpoint.ckpt',
            tags_csv='/path/to/hparams_file.csv',
            map_location=None
        )

        # predict
        pretrained_model.eval()
        pretrained_model.freeze()
        y_hat = pretrained_model(x)
    """
    hparams = load_hparams_from_tags_csv(tags_csv)
    hparams.__setattr__('on_gpu', False)

    if map_location is not None:
        checkpoint = torch.load(weights_path, map_location=map_location)
    else:
        checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)

    # add the hparams from the csv file to the checkpoint
    checkpoint['hparams'] = vars(hparams)

    model = cls._load_model_state(checkpoint)
    return model
def load_from_checkpoint(
    cls,
    checkpoint_path: str,
    map_location: Any = None,
    hparams_file: Optional[str] = None,
    strict: bool = True,
    **kwargs,
):
    """
    Loads ModelPT from checkpoint, with some maintenance of restoration.
    For documentation, please refer to the LightningModule.load_from_checkpoint() documentation.
    """
    checkpoint = None
    try:
        cls._set_model_restore_state(is_being_restored=True)
        # TODO: replace with proper PTL API
        with pl_legacy_patch():
            if map_location is not None:
                checkpoint = pl_load(checkpoint_path, map_location=map_location)
            else:
                checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

        if hparams_file is not None:
            extension = hparams_file.split(".")[-1]
            if extension.lower() == "csv":
                hparams = load_hparams_from_tags_csv(hparams_file)
            elif extension.lower() in ("yml", "yaml"):
                hparams = load_hparams_from_yaml(hparams_file)
            else:
                raise ValueError(".csv, .yml or .yaml is required for `hparams_file`")

            hparams["on_gpu"] = False

            # overwrite hparams by the given file
            checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams

        # for past checkpoints we need to add the new key
        if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint:
            checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {}

        # override the hparams with values that were passed in
        # TODO: can we do this without overriding?
        config_kwargs = kwargs.copy()
        if 'trainer' in config_kwargs:
            config_kwargs.pop('trainer')
        checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(config_kwargs)

        if 'cfg' in kwargs:
            model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
        else:
            model = cls._load_model_state(
                checkpoint, strict=strict, cfg=checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY], **kwargs
            )
        checkpoint = model

    finally:
        cls._set_model_restore_state(is_being_restored=False)
    return checkpoint
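# A hedged usage sketch of the ModelPT restore path above (illustrative only):
# `MyNeMoModel` is a hypothetical ModelPT subclass and both paths are placeholders.
# Passing `hparams_file` replaces the hyperparameters stored in the checkpoint,
# and extra keyword arguments override individual config values.
restored = MyNeMoModel.load_from_checkpoint(
    checkpoint_path='/path/to/model.ckpt',
    map_location='cpu',                    # keep the weights on CPU while restoring
    hparams_file='/path/to/hparams.yaml',  # optional .csv/.yml/.yaml override
    strict=False,                          # tolerate missing/unexpected state_dict keys
)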
def load_from_checkpoint(
    cls,
    checkpoint_path: str,
    map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
    tags_csv: Optional[str] = None,
) -> 'LightningModule':
    r"""
    Primary way of loading a model from a checkpoint.
    When Lightning saves a checkpoint it stores the hyperparameters in the checkpoint
    if you initialized your LightningModule with an argument called `hparams`
    which is a Namespace (output of using argparse to parse command line arguments).

    Example
    -------
    .. code-block:: python

        from argparse import Namespace

        class MyModel(LightningModule):
            def __init__(self, hparams):
                super().__init__()
                self.learning_rate = hparams.learning_rate

        hparams = Namespace(**{'learning_rate': 0.1})
        model = MyModel(hparams)

    Args:
        checkpoint_path: Path to checkpoint.
        map_location:
            If your checkpoint saved a GPU model and you now load on CPUs
            or a different number of GPUs, use this to map to the new setup.
            The behaviour is the same as in
            `torch.load <https://pytorch.org/docs/stable/torch.html#torch.load>`_.
        tags_csv: Optional path to a .csv file with two columns (key, value)
            as in this example::

                key,value
                drop_prob,0.2
                batch_size,32

            You most likely won't need this since Lightning will always save the
            hyperparameters to the checkpoint. However, if your checkpoint weights
            don't have the hyperparameters saved, use this method to pass in a .csv
            file with the hparams you'd like to use. These will be converted into
            an argparse.Namespace and passed into your LightningModule for use.

    Return:
        LightningModule with loaded weights and hyperparameters (if available).

    Example
    -------
    .. code-block:: python

        # load weights without mapping ...
        MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

        # or load weights mapping all weights from GPU 1 to GPU 0 ...
        map_location = {'cuda:1': 'cuda:0'}
        MyLightningModule.load_from_checkpoint(
            'path/to/checkpoint.ckpt',
            map_location=map_location
        )

        # or load weights and hyperparameters from separate files.
        pretrained_model = MyLightningModule.load_from_checkpoint(
            'path/to/checkpoint.ckpt',
            tags_csv='/path/to/hparams_file.csv'
        )

        # predict
        pretrained_model.eval()
        pretrained_model.freeze()
        y_hat = pretrained_model(x)
    """
    if map_location is not None:
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
    else:
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

    if tags_csv is not None:
        # add the hparams from the csv file to the checkpoint
        hparams = load_hparams_from_tags_csv(tags_csv)
        hparams.__setattr__('on_gpu', False)
        checkpoint['hparams'] = vars(hparams)

    model = cls._load_model_state(checkpoint)
    return model
def load_from_checkpoint(
    cls,
    checkpoint_path: str,
    map_location: Any = None,
    hparams_file: Optional[str] = None,
    strict: bool = True,
    **kwargs,
):
    """
    Loads a Megatron-LM checkpoint and converts it, with some maintenance of restoration.
    For documentation, please refer to the LightningModule.load_from_checkpoint() documentation.
    """
    checkpoint = None
    try:
        cls._set_model_restore_state(is_being_restored=True)
        # TODO: replace with proper PTL API
        with pl_legacy_patch():
            if map_location is not None:
                old_checkpoint = pl_load(checkpoint_path, map_location=map_location)
            else:
                old_checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

        total_params = [0]
        checkpoint = OrderedDict()
        checkpoint['state_dict'] = OrderedDict()
        parse_weights(
            old_checkpoint['model'], "", total_params, checkpoint['state_dict'], translator=kwargs['translator']
        )
        print('converted {:.2f}M parameters'.format(total_params[0] / 1e6))

        if hparams_file is not None:
            extension = hparams_file.split(".")[-1]
            if extension.lower() == "csv":
                hparams = load_hparams_from_tags_csv(hparams_file)
            elif extension.lower() in ("yml", "yaml"):
                hparams = load_hparams_from_yaml(hparams_file)
            else:
                raise ValueError(".csv, .yml or .yaml is required for `hparams_file`")

            hparams["on_gpu"] = False

            # overwrite hparams by the given file
            checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams

        check_point_version = old_checkpoint.get('checkpoint_version', 0)
        if check_point_version < 3:
            # need to transpose the query_key_value variables
            if hparams_file is not None:
                np = hparams['cfg']['num_attention_heads']
            elif 'config' in old_checkpoint and 'num-attention-heads' in old_checkpoint['config']:
                np = old_checkpoint['config']['num-attention-heads']
            else:
                logging.warning("cannot determine the number of attention heads")
                raise ValueError('need to know number of attention heads')

            if check_point_version == 0:
                # 3, np, hn -> np, 3, hn
                for key in checkpoint['state_dict']:
                    if key.find('query_key_value') >= 0:
                        weight = checkpoint['state_dict'][key]
                        if len(weight.size()) == 2:
                            # weight
                            weight = weight.view(3, np, -1, weight.size()[-1])
                            weight = weight.transpose(0, 1).contiguous()
                            checkpoint['state_dict'][key] = weight.view(-1, weight.size()[-1])
                        else:
                            # bias
                            weight = weight.view(3, np, -1)
                            weight = weight.transpose(0, 1).contiguous()
                            checkpoint['state_dict'][key] = weight.view(-1)
            elif check_point_version == 1:
                # np, hn, 3 -> np, 3, hn
                for key in checkpoint['state_dict']:
                    if key.find('query_key_value') >= 0:
                        weight = checkpoint['state_dict'][key]
                        if len(weight.size()) == 2:
                            # weight
                            weight = weight.view(np, -1, 3, weight.size()[-1])
                            weight = weight.transpose(1, 2).contiguous()
                            checkpoint['state_dict'][key] = weight
                        else:
                            # bias
                            weight = weight.view(np, -1, 3)
                            weight = weight.transpose(1, 2).contiguous()
                            checkpoint['state_dict'][key] = weight

        # for past checkpoints we need to add the new key
        if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint:
            checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {}

        # override the hparams with values that were passed in
        # TODO: can we do this without overriding?
        config_kwargs = kwargs.copy()
        if 'trainer' in config_kwargs:
            config_kwargs.pop('trainer')
        checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(config_kwargs)

        if 'cfg' in kwargs:
            model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
        else:
            model = cls._load_model_state(
                checkpoint, strict=strict, cfg=checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].cfg, **kwargs
            )

        # register the artifacts
        cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].cfg
        if cfg.tokenizer.model is not None:
            model.register_artifact("tokenizer.tokenizer_model", cfg.tokenizer.model)
        if cfg.tokenizer.vocab_file is not None:
            model.register_artifact("tokenizer.vocab_file", cfg.tokenizer.vocab_file)
        if cfg.tokenizer.merge_file is not None:
            model.register_artifact("tokenizer.merge_file", cfg.tokenizer.merge_file)

        checkpoint = model

    finally:
        cls._set_model_restore_state(is_being_restored=False)
    return checkpoint
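# A minimal runnable sketch (not part of the original code) of the
# checkpoint_version == 0 conversion above: a fused query_key_value weight
# stored as (3, np, hn, hidden) is reordered to (np, 3, hn, hidden) and
# flattened back to 2-D. The sizes below (np=2 heads, hn=4, hidden=8) are
# made-up toy values.
import torch

num_heads, head_dim, hidden = 2, 4, 8
qkv = torch.arange(3 * num_heads * head_dim * hidden, dtype=torch.float32)
qkv = qkv.view(3 * num_heads * head_dim, hidden)   # 2-D layout as stored in the old checkpoint

w = qkv.view(3, num_heads, -1, qkv.size()[-1])     # 3, np, hn, hidden
w = w.transpose(0, 1).contiguous()                 # np, 3, hn, hidden
converted = w.view(-1, w.size()[-1])               # back to 2-D in the new layout

assert converted.shape == qkv.shape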
def load_from_checkpoint(
    cls,
    checkpoint_path: str,
    map_location: Any = None,
    hparams_file: Optional[str] = None,
    strict: bool = True,
    **kwargs,
):
    """
    Loads ModelPT from checkpoint, with some maintenance of restoration.
    For documentation, please refer to the LightningModule.load_from_checkpoint() documentation.
    """
    checkpoint = None
    try:
        cls._set_model_restore_state(is_being_restored=True)
        # TODO: replace with proper PTL API
        with pl_legacy_patch():
            if map_location is not None:
                checkpoint = pl_load(checkpoint_path, map_location=map_location)
            else:
                checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

        if hparams_file is not None:
            extension = hparams_file.split(".")[-1]
            if extension.lower() == "csv":
                hparams = load_hparams_from_tags_csv(hparams_file)
            elif extension.lower() in ("yml", "yaml"):
                hparams = load_hparams_from_yaml(hparams_file)
            else:
                raise ValueError(".csv, .yml or .yaml is required for `hparams_file`")

            hparams["on_gpu"] = False

            # overwrite hparams by the given file
            checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams

        # for past checkpoints we need to add the new key
        if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint:
            checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {}

        # override the hparams with values that were passed in
        cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].get('cfg', checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY])
        # TODO: can we do this without overriding?
        config_kwargs = kwargs.copy()
        if 'trainer' in config_kwargs:
            config_kwargs.pop('trainer')
        cfg.update(config_kwargs)

        if cfg.get('megatron_amp_O2', False):
            new_state_dict = {}
            for key in checkpoint['state_dict'].keys():
                new_key = key.replace('model.', 'model.module.', 1)
                new_state_dict[new_key] = checkpoint['state_dict'][key]
            checkpoint['state_dict'] = new_state_dict

        if 'cfg' in kwargs:
            model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
        else:
            model = cls._load_model_state(checkpoint, strict=strict, cfg=cfg, **kwargs)
        # cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].cfg

        # NMT models do not have a `tokenizer` attribute; they instead have
        # encoder_tokenizer and decoder_tokenizer attributes.
        if hasattr(cfg, "tokenizer"):
            if cfg.tokenizer.get("tokenizer_model") is not None:
                model.register_artifact("tokenizer.tokenizer_model", cfg.tokenizer.tokenizer_model)
            if cfg.tokenizer.get("vocab_file") is not None:
                model.register_artifact("tokenizer.vocab_file", cfg.tokenizer.vocab_file)
            if cfg.tokenizer.get("merge_file") is not None:
                model.register_artifact("tokenizer.merge_file", cfg.tokenizer.merge_file)

        if hasattr(cfg, "encoder_tokenizer"):
            if cfg.encoder_tokenizer.get("tokenizer_model") is not None:
                model.register_artifact("encoder_tokenizer.tokenizer_model", cfg.encoder_tokenizer.tokenizer_model)
            if cfg.encoder_tokenizer.get("vocab_file") is not None:
                model.register_artifact("encoder_tokenizer.vocab_file", cfg.encoder_tokenizer.vocab_file)
            if cfg.encoder_tokenizer.get("merge_file") is not None:
                model.register_artifact("encoder_tokenizer.merge_file", cfg.encoder_tokenizer.merge_file)

        if hasattr(cfg, "decoder_tokenizer"):
            if cfg.decoder_tokenizer.get("tokenizer_model") is not None:
                model.register_artifact("decoder_tokenizer.tokenizer_model", cfg.decoder_tokenizer.tokenizer_model)
            if cfg.decoder_tokenizer.get("vocab_file") is not None:
                model.register_artifact("decoder_tokenizer.vocab_file", cfg.decoder_tokenizer.vocab_file)
            if cfg.decoder_tokenizer.get("merge_file") is not None:
                model.register_artifact("decoder_tokenizer.merge_file", cfg.decoder_tokenizer.merge_file)

        checkpoint = model

    finally:
        cls._set_model_restore_state(is_being_restored=False)
    return checkpoint