Example #1
    def load_from_metrics(cls, weights_path, tags_csv, overwrite_hparams=None):
        overwrite_hparams = overwrite_hparams or {}

        hparams = load_hparams_from_tags_csv(tags_csv)
        hparams.__dict__["logger"] = eval(
            hparams.__dict__.get("logger", "None"))
        if (str(hparams.sampling_probs) == "nan"
                or str(hparams.sampling_probs) == "None"
                or len(hparams.sampling_probs) == 0):
            hparams.__dict__["sampling_probs"] = None

        if (str(hparams.audio_file) == "nan"
                or str(hparams.audio_file) == "None"
                or len(hparams.audio_file) == 0):
            hparams.__dict__["audio_file"] = None

        hparams.__setattr__("on_gpu", False)
        hparams.__dict__.update(overwrite_hparams)

        # load on CPU only to avoid OOM issues
        # then it's up to the user to put it back on GPUs
        checkpoint = torch.load(weights_path,
                                map_location=lambda storage, loc: storage)

        # load the state_dict on the model automatically
        model = cls(hparams)
        model.load_state_dict(checkpoint["state_dict"])

        optimizer = model.configure_optimizers()
        optimizer.load_state_dict(checkpoint["optimizer_states"][0])

        # give model a chance to load something
        model.on_load_checkpoint(checkpoint)

        return model
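A minimal usage sketch for this variant, assuming the method above is a classmethod on a LightningModule subclass called MyModule here (the class name, paths, and override dict are placeholders):

# hypothetical paths; overwrite_hparams patches individual fields before the model is built
model = MyModule.load_from_metrics(
    weights_path="checkpoints/last.ckpt",
    tags_csv="meta_tags.csv",
    overwrite_hparams={"batch_size": 16},
)
model.eval()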
Example #2
def test_loading_meta_tags(tmpdir):
    """ test for backward compatibility to meta_tags.csv """
    tutils.reset_seed()

    hparams = EvalModelTemplate.get_default_hparams()

    # save tags
    logger = tutils.get_default_logger(tmpdir)
    logger.log_hyperparams(Namespace(some_str='a_str', an_int=1, a_float=2.0))
    logger.log_hyperparams(hparams)
    logger.save()

    # load hparams
    path_expt_dir = tutils.get_data_path(logger, path_dir=tmpdir)
    hparams_path = os.path.join(path_expt_dir,
                                TensorBoardLogger.NAME_HPARAMS_FILE)
    hparams = load_hparams_from_yaml(hparams_path)

    # save as legacy meta_tags.csv
    tags_path = os.path.join(path_expt_dir, 'meta_tags.csv')
    save_hparams_to_tags_csv(tags_path, hparams)

    tags = load_hparams_from_tags_csv(tags_path)

    assert hparams == tags
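For reference, a minimal sketch of the tags-CSV round-trip the test exercises; it assumes the helpers are importable from pytorch_lightning.core.saving (as in recent Lightning versions) and uses an arbitrary file name:

from pytorch_lightning.core.saving import (
    load_hparams_from_tags_csv,
    save_hparams_to_tags_csv,
)

# write a flat key/value CSV and read it back into a dict
save_hparams_to_tags_csv("meta_tags.csv", {"drop_prob": 0.2, "batch_size": 32})
restored = load_hparams_from_tags_csv("meta_tags.csv")
assert restored["batch_size"] == 32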
Example #3
    def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
        r"""
        You should use `load_from_checkpoint` instead!
        However, if your .ckpt weights don't have the hyperparameters saved, use this method to pass
        in a .csv with the hparams you'd like to use. These will be converted into an argparse.Namespace
        and passed into your LightningModule for use.

        Args:

            weights_path (str): Path to a PyTorch checkpoint
            tags_csv (str): Path to a .csv with two columns (key, value) as in this

            Example::
                key,value
                drop_prob,0.2
                batch_size,32

            map_location (dict | str | torch.device | function):
                If your checkpoint saved a GPU model and you now load on CPUs
                or a different number of GPUs, use this to map to the new setup
                (example: {'cuda:1':'cuda:0'}).
                The behaviour is the same as in
                `torch.load <https://pytorch.org/docs/stable/torch.html#torch.load>`_.

        Return:
            LightningModule with loaded weights and hyperparameters (if available).

        Example
        -------
        .. code-block:: python

            pretrained_model = MyLightningModule.load_from_metrics(
                weights_path='/path/to/pytorch_checkpoint.ckpt',
                tags_csv='/path/to/hparams_file.csv',
                map_location=None
            )

            # predict
            pretrained_model.eval()
            pretrained_model.freeze()
            y_hat = pretrained_model(x)
        """

        hparams = load_hparams_from_tags_csv(tags_csv)
        hparams.__setattr__('on_gpu', False)

        if map_location is not None:
            checkpoint = torch.load(weights_path, map_location=map_location)
        else:
            checkpoint = torch.load(weights_path,
                                    map_location=lambda storage, loc: storage)

        # add the hparams from csv file to checkpoint
        checkpoint['hparams'] = vars(hparams)

        model = cls._load_model_state(checkpoint)
        return model
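The map_location argument in the example above is handed straight to torch.load, so any form torch.load accepts will work; a short sketch of equivalent ways to force CPU loading (the checkpoint path is a placeholder):

import torch

# all three variants remap every stored tensor onto the CPU
ckpt_a = torch.load("model.ckpt", map_location="cpu")
ckpt_b = torch.load("model.ckpt", map_location=torch.device("cpu"))
ckpt_c = torch.load("model.ckpt", map_location=lambda storage, loc: storage)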
Example #4
    def load_from_checkpoint(
        cls,
        checkpoint_path: str,
        map_location: Any = None,
        hparams_file: Optional[str] = None,
        strict: bool = True,
        **kwargs,
    ):
        """
        Loads ModelPT from checkpoint, with some maintenance of restoration.
        For documentation, please refer to LightningModule.load_from_checkpoint() documentation.
        """
        checkpoint = None
        try:
            cls._set_model_restore_state(is_being_restored=True)
            # TODO: replace with proper PTL API
            with pl_legacy_patch():
                if map_location is not None:
                    checkpoint = pl_load(checkpoint_path, map_location=map_location)
                else:
                    checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

            if hparams_file is not None:
                extension = hparams_file.split(".")[-1]
                if extension.lower() == "csv":
                    hparams = load_hparams_from_tags_csv(hparams_file)
                elif extension.lower() in ("yml", "yaml"):
                    hparams = load_hparams_from_yaml(hparams_file)
                else:
                    raise ValueError(".csv, .yml or .yaml is required for `hparams_file`")

                hparams["on_gpu"] = False

                # overwrite hparams by the given file
                checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams

            # for past checkpoint need to add the new key
            if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint:
                checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {}
            # override the hparams with values that were passed in
            # TODO: can we do this without overriding?
            config_kwargs = kwargs.copy()
            if 'trainer' in config_kwargs:
                config_kwargs.pop('trainer')
            checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(config_kwargs)

            if 'cfg' in kwargs:
                model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
            else:
                model = cls._load_model_state(
                    checkpoint, strict=strict, cfg=checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY], **kwargs
                )
            checkpoint = model

        finally:
            cls._set_model_restore_state(is_being_restored=False)
        return checkpoint
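A hedged usage sketch of the hparams_file override path above, assuming a ModelPT subclass called MyNeMoModel; the class name, checkpoint path, and YAML path are placeholders:

# hyperparameters are taken from the external YAML instead of the checkpoint itself
model = MyNeMoModel.load_from_checkpoint(
    checkpoint_path="model_weights.ckpt",
    hparams_file="hparams.yaml",
    map_location="cpu",
    strict=False,
)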
Example #5
    def load_from_checkpoint(
            cls,
            checkpoint_path: str,
            map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
            tags_csv: Optional[str] = None,
    ) -> 'LightningModule':
        r"""

        Primary way of loading model from a checkpoint. When Lightning saves a checkpoint
        it stores the hyperparameters in the checkpoint if you initialized your LightningModule
        with an argument called `hparams`, which is a Namespace (the output of using argparse
        to parse command-line arguments).

        Example
        -------
        .. code-block:: python

            from argparse import Namespace

            class MyModel(LightningModule):
                def __init__(self, hparams):
                    super().__init__()
                    self.learning_rate = hparams.learning_rate

            hparams = Namespace(**{'learning_rate': 0.1})
            model = MyModel(hparams)

        Args:
            checkpoint_path: Path to checkpoint.
            map_location:
                If your checkpoint saved a GPU model and you now load on CPUs
                or a different number of GPUs, use this to map to the new setup.
                The behaviour is the same as in
                `torch.load <https://pytorch.org/docs/stable/torch.html#torch.load>`_.
            tags_csv: Optional path to a .csv file with two columns (key, value)
                as in this example::

                    key,value
                    drop_prob,0.2
                    batch_size,32

                You most likely won't need this since Lightning will always save the hyperparameters
                to the checkpoint.
                However, if your checkpoint weights don't have the hyperparameters saved,
                use this method to pass in a .csv file with the hparams you'd like to use.
                These will be converted into an argparse.Namespace and passed into your
                LightningModule for use.

        Return:
            LightningModule with loaded weights and hyperparameters (if available).

        Example
        -------
        .. code-block:: python

            # load weights without mapping ...
            pretrained_model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

            # or load weights mapping all weights from GPU 1 to GPU 0 ...
            map_location = {'cuda:1': 'cuda:0'}
            pretrained_model = MyLightningModule.load_from_checkpoint(
                'path/to/checkpoint.ckpt',
                map_location=map_location
            )

            # or load weights and hyperparameters from separate files.
            pretrained_model = MyLightningModule.load_from_checkpoint(
                'path/to/checkpoint.ckpt',
                tags_csv='/path/to/hparams_file.csv'
            )

            # predict
            pretrained_model.eval()
            pretrained_model.freeze()
            y_hat = pretrained_model(x)
        """
        if map_location is not None:
            checkpoint = torch.load(checkpoint_path, map_location=map_location)
        else:
            checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

        if tags_csv is not None:
            # add the hparams from csv file to checkpoint
            hparams = load_hparams_from_tags_csv(tags_csv)
            hparams.__setattr__('on_gpu', False)
            checkpoint['hparams'] = vars(hparams)

        model = cls._load_model_state(checkpoint)
        return model
Example #6
def load_from_checkpoint(
    cls,
    checkpoint_path: str,
    map_location: Any = None,
    hparams_file: Optional[str] = None,
    strict: bool = True,
    **kwargs,
):
    """
        Loads a Megatron-LM checkpoint and converts it, with some maintenance of restoration.
        For documentation, please refer to LightningModule.load_from_checkpoint() documentation.
        """
    checkpoint = None
    try:
        cls._set_model_restore_state(is_being_restored=True)
        # TODO: replace with proper PTL API

        with pl_legacy_patch():
            if map_location is not None:
                old_checkpoint = pl_load(checkpoint_path,
                                         map_location=map_location)
            else:
                old_checkpoint = pl_load(
                    checkpoint_path, map_location=lambda storage, loc: storage)

        total_params = [0]
        checkpoint = OrderedDict()
        checkpoint['state_dict'] = OrderedDict()
        parse_weights(old_checkpoint['model'],
                      "",
                      total_params,
                      checkpoint['state_dict'],
                      translator=kwargs['translator'])
        print('converted {:.2f}M parameters'.format(total_params[0] / 1e6))

        if hparams_file is not None:
            extension = hparams_file.split(".")[-1]
            if extension.lower() == "csv":
                hparams = load_hparams_from_tags_csv(hparams_file)
            elif extension.lower() in ("yml", "yaml"):
                hparams = load_hparams_from_yaml(hparams_file)
            else:
                raise ValueError(
                    ".csv, .yml or .yaml is required for `hparams_file`")

            hparams["on_gpu"] = False

            # overwrite hparams by the given file
            checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams

        check_point_version = old_checkpoint.get('checkpoint_version', 0)
        if check_point_version < 3:
            # need to do the transpose of query_key_value variables
            if hparams_file is not None:
                np = hparams['cfg']['num_attention_heads']
            elif 'config' in old_checkpoint and 'num-attention-heads' in old_checkpoint[
                    'config']:
                np = old_checkpoint['config']['num-attention-heads']

            else:
                logging.warning("cannot determine the number attention heads")
                raise ValueError('need to know number of attention heads')

            if check_point_version == 0:
                # 3, np, hn -> np, 3, hn
                for key in checkpoint['state_dict']:
                    if key.find('query_key_value') >= 0:
                        weight = checkpoint['state_dict'][key]
                        if len(weight.size()) == 2:
                            # weight
                            weight = weight.view(3, np, -1, weight.size()[-1])
                            weight = weight.transpose(0, 1).contiguous()
                            checkpoint['state_dict'][key] = weight.view(
                                -1,
                                weight.size()[-1])
                        else:
                            # bias
                            weight = weight.view(3, np, -1)
                            weight = weight.transpose(0, 1).contiguous()
                            checkpoint['state_dict'][key] = weight.view(-1)
            elif check_point_version == 1:
                # np, hn, 3 -> np, 3, hn
                for key in checkpoint['state_dict']:
                    if key.find('query_key_value') >= 0:
                        weight = checkpoint['state_dict'][key]
                        if len(weight.size()) == 2:
                            # weight
                            weight = weight.view(np, -1, 3, weight.size()[-1])
                            weight = weight.transpose(1, 2).contiguous()
                            checkpoint['state_dict'][key] = weight
                        else:
                            # bias
                            weight = weight.view(np, -1, 3)
                            weight = weight.transpose(1, 2).contiguous()
                            checkpoint['state_dict'][key] = weight

        # for past checkpoint need to add the new key
        if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint:
            checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {}
        # override the hparams with values that were passed in
        # TODO: can we do this without overriding?
        config_kwargs = kwargs.copy()
        if 'trainer' in config_kwargs:
            config_kwargs.pop('trainer')
        checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(config_kwargs)

        if 'cfg' in kwargs:
            model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
        else:
            model = cls._load_model_state(
                checkpoint,
                strict=strict,
                cfg=checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].cfg,
                **kwargs)
            # register the artifacts
            cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].cfg
            if cfg.tokenizer.model is not None:
                model.register_artifact("tokenizer.tokenizer_model",
                                        cfg.tokenizer.model)
            if cfg.tokenizer.vocab_file is not None:
                model.register_artifact("tokenizer.vocab_file",
                                        cfg.tokenizer.vocab_file)
            if cfg.tokenizer.merge_file is not None:
                model.register_artifact("tokenizer.merge_file",
                                        cfg.tokenizer.merge_file)
        checkpoint = model

    finally:
        cls._set_model_restore_state(is_being_restored=False)
    return checkpoint
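A self-contained sketch of the version-0 query_key_value re-ordering above, with made-up shapes (2 attention heads, 4 hidden units per head) purely to illustrate the reshape/transpose pattern:

import torch

num_heads, hidden_per_head, in_dim = 2, 4, 8
weight = torch.randn(3 * num_heads * hidden_per_head, in_dim)

# 3, np, hn -> np, 3, hn, then flatten back to the original 2-D layout
reordered = (
    weight.view(3, num_heads, hidden_per_head, in_dim)
    .transpose(0, 1)
    .contiguous()
    .view(-1, in_dim)
)
assert reordered.shape == weight.shape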
Example #7
    def load_from_checkpoint(
        cls,
        checkpoint_path: str,
        map_location: Any = None,
        hparams_file: Optional[str] = None,
        strict: bool = True,
        **kwargs,
    ):
        """
        Loads ModelPT from checkpoint, with some maintenance of restoration.
        For documentation, please refer to LightningModule.load_from_checkpoint() documentation.
        """
        checkpoint = None
        try:
            cls._set_model_restore_state(is_being_restored=True)
            # TODO: replace with proper PTL API
            with pl_legacy_patch():
                if map_location is not None:
                    checkpoint = pl_load(checkpoint_path, map_location=map_location)
                else:
                    checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

            if hparams_file is not None:
                extension = hparams_file.split(".")[-1]
                if extension.lower() == "csv":
                    hparams = load_hparams_from_tags_csv(hparams_file)
                elif extension.lower() in ("yml", "yaml"):
                    hparams = load_hparams_from_yaml(hparams_file)
                else:
                    raise ValueError(".csv, .yml or .yaml is required for `hparams_file`")

                hparams["on_gpu"] = False

                # overwrite hparams by the given file
                checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams

            # for past checkpoint need to add the new key
            if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint:
                checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {}
            # override the hparams with values that were passed in
            cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].get('cfg', checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY])
            # TODO: can we do this without overriding?
            config_kwargs = kwargs.copy()
            if 'trainer' in config_kwargs:
                config_kwargs.pop('trainer')
            cfg.update(config_kwargs)

            if cfg.get('megatron_amp_O2', False):
                new_state_dict = {}
                for key in checkpoint['state_dict'].keys():
                    new_key = key.replace('model.', 'model.module.', 1)
                    new_state_dict[new_key] = checkpoint['state_dict'][key]
                checkpoint['state_dict'] = new_state_dict

            if 'cfg' in kwargs:
                model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
            else:
                model = cls._load_model_state(checkpoint, strict=strict, cfg=cfg, **kwargs)
                # cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].cfg

            # NMT models do not have a `tokenizer` attribute, they instead have an encoder_tokenizer and decoder_tokenizer attribute.
            if hasattr(cfg, "tokenizer"):
                if cfg.tokenizer.get("tokenizer_model") is not None:
                    model.register_artifact("tokenizer.tokenizer_model", cfg.tokenizer.tokenizer_model)
                if cfg.tokenizer.get("vocab_file") is not None:
                    model.register_artifact("tokenizer.vocab_file", cfg.tokenizer.vocab_file)
                if cfg.tokenizer.get("merge_file") is not None:
                    model.register_artifact("tokenizer.merge_file", cfg.tokenizer.merge_file)

            if hasattr(cfg, "encoder_tokenizer"):
                if cfg.encoder_tokenizer.get("tokenizer_model") is not None:
                    model.register_artifact("encoder_tokenizer.tokenizer_model", cfg.encoder_tokenizer.tokenizer_model)
                if cfg.encoder_tokenizer.get("vocab_file") is not None:
                    model.register_artifact("encoder_tokenizer.vocab_file", cfg.encoder_tokenizer.vocab_file)
                if cfg.encoder_tokenizer.get("merge_file") is not None:
                    model.register_artifact("encoder_tokenizer.merge_file", cfg.encoder_tokenizer.merge_file)

            if hasattr(cfg, "decoder_tokenizer"):
                if cfg.decoder_tokenizer.get("tokenizer_model") is not None:
                    model.register_artifact("decoder_tokenizer.tokenizer_model", cfg.decoder_tokenizer.tokenizer_model)
                if cfg.decoder_tokenizer.get("vocab_file") is not None:
                    model.register_artifact("decoder_tokenizer.vocab_file", cfg.decoder_tokenizer.vocab_file)
                if cfg.decoder_tokenizer.get("merge_file") is not None:
                    model.register_artifact("decoder_tokenizer.merge_file", cfg.decoder_tokenizer.merge_file)

            checkpoint = model

        finally:
            cls._set_model_restore_state(is_being_restored=False)
        return checkpoint
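A standalone sketch of the key renaming applied when megatron_amp_O2 is enabled above; the state-dict entries are made up and only illustrate the prefix rewrite:

# the first occurrence of 'model.' in each key becomes 'model.module.'
state_dict = {"model.encoder.weight": 0, "model.decoder.bias": 1}
renamed = {k.replace("model.", "model.module.", 1): v for k, v in state_dict.items()}
# renamed == {"model.module.encoder.weight": 0, "model.module.decoder.bias": 1}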