Example #1
    @classmethod
    def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', **kwargs):
        """
        Load a :class:`~fairseq.models.FairseqModel` from a pre-trained model
        file. Downloads and caches the pre-trained model file if needed.

        The base implementation returns a :class:`fairseq.hub_utils.Generator`,
        which can be used to generate translations or sample from language
        models. The underlying :class:`~fairseq.models.FairseqModel` can be
        accessed via the *generator.models* attribute.

        Other models may override this to implement custom PyTorch Hub APIs.

        Args:
            model_name_or_path (str): either the name of a pre-trained model to
                load or a path/URL to a pre-trained model state dict
            checkpoint_file (str, optional): colon-separated list of checkpoint
                files in the model archive to ensemble (default: 'model.pt')
            data_name_or_path (str, optional): point args.data to the archive
                at the given path/URL. Can start with '.' or './' to reuse the
                model archive path.
        """
        from fairseq import hub_utils
        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),  # resolve model shortcut names to download URLs
            **kwargs,
        )
        print(x['args'])  # log the resolved model/task arguments
        return hub_utils.Generator(x['args'], x['task'], x['models'])
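For reference, a minimal usage sketch of this API, assuming 'transformer.wmt14.en-fr' is one of the shortcut names registered by TransformerModel.hub_models(); the generator.models attribute is the one documented in the docstring above:

    from fairseq.models.transformer import TransformerModel

    generator = TransformerModel.from_pretrained(
        'transformer.wmt14.en-fr',   # shortcut name, resolved via hub_models()
        checkpoint_file='model.pt',
        data_name_or_path='.',       # '.' reuses the extracted archive path
    )
    models = generator.models        # the underlying FairseqModel ensemble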
Example #2
    @classmethod
    def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path=None, **kwargs):
        """
        Load a :class:`~fairseq.models.FairseqModel` from a pre-trained model
        file. Downloads and caches the pre-trained model file if needed.

        The base implementation returns a :class:`fairseq.hub_utils.Generator`,
        which can be used to generate translations or sample from language
        models. The underlying :class:`~fairseq.models.FairseqModel` can be
        accessed via the *generator.models* attribute.

        Other models may override this to implement custom PyTorch Hub APIs.

        Args:
            model_name_or_path (str): either the name of a pre-trained model to
                load or a path/URL to a pre-trained model state dict
            checkpoint_file (str, optional): colon-separated list of checkpoint
                files in the model archive to ensemble (default: 'model.pt')
            data_name_or_path (str, optional): point args.data to the archive
                at the given path/URL. Can start with '.' or './' to reuse the
                model archive path.
        """
        import os

        from fairseq import checkpoint_utils, file_utils, hub_utils

        if hasattr(cls, 'hub_models'):
            archive_map = cls.hub_models()
            if model_name_or_path in archive_map:
                model_name_or_path = archive_map[model_name_or_path]
            if data_name_or_path is not None and data_name_or_path in archive_map:
                data_name_or_path = archive_map[data_name_or_path]

        model_path = file_utils.load_archive_file(model_name_or_path)

        # convenience hack for loading data and BPE codes from model archive
        if data_name_or_path is not None:
            if data_name_or_path.startswith('.'):
                kwargs['data'] = os.path.abspath(os.path.join(model_path, data_name_or_path))
            else:
                kwargs['data'] = file_utils.load_archive_file(data_name_or_path)
        # expose BPE codes / sentencepiece models shipped in the archive
        # to the task via the corresponding arg overrides
        for file, arg in {
            'code': 'bpe_codes',
            'bpecodes': 'bpe_codes',
            'sentencepiece.bpe.model': 'sentencepiece_vocab',
        }.items():
            path = os.path.join(model_path, file)
            if os.path.exists(path):
                kwargs[arg] = path

        # load the checkpoint(s) as an ensemble, together with the task
        models, args, task = checkpoint_utils.load_model_ensemble_and_task(
            [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(':')],
            arg_overrides=kwargs,
        )

        print(args)  # log the resolved model/task arguments

        return hub_utils.Generator(args, task, models)
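To make the two path conventions concrete, here is a small self-contained sketch of the expansion logic used above. The cache location is illustrative (and assumes POSIX path separators); only the splitting and joining mirror the code:

    import os

    model_path = '/cache/wmt14.en-fr'        # illustrative archive location
    checkpoint_file = 'model1.pt:model2.pt'  # two checkpoints to ensemble

    # a colon-separated checkpoint_file yields one path per ensemble member
    paths = [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(':')]
    assert paths == ['/cache/wmt14.en-fr/model1.pt', '/cache/wmt14.en-fr/model2.pt']

    # a '.'-prefixed data_name_or_path resolves relative to the model archive
    data = os.path.abspath(os.path.join(model_path, '.'))
    assert data == '/cache/wmt14.en-fr'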