Example 1
import argparse
import os


def from_pretrained(
    model_name_or_path,
    checkpoint_file="model.pt",
    data_name_or_path=".",
    archive_map=None,
    **kwargs
):
    from fairseq import checkpoint_utils, file_utils, utils

    if archive_map is not None:
        if model_name_or_path in archive_map:
            model_name_or_path = archive_map[model_name_or_path]
        if data_name_or_path is not None and data_name_or_path in archive_map:
            data_name_or_path = archive_map[data_name_or_path]

        # allow archive_map to set default arg_overrides (e.g., tokenizer, bpe)
        # for each model
        if isinstance(model_name_or_path, dict):
            for k, v in model_name_or_path.items():
                if k == "checkpoint_file":
                    checkpoint_file = v
                elif (
                    k != "path"
                    # only set kwargs that don't already have overrides
                    and k not in kwargs
                ):
                    kwargs[k] = v
            model_name_or_path = model_name_or_path["path"]

    model_path = file_utils.load_archive_file(model_name_or_path)

    # convenience hack for loading data and BPE codes from model archive
    if data_name_or_path.startswith("."):
        kwargs["data"] = os.path.abspath(os.path.join(model_path, data_name_or_path))
    else:
        kwargs["data"] = file_utils.load_archive_file(data_name_or_path)
    for file, arg in {
        "code": "bpe_codes",
        "bpecodes": "bpe_codes",
        "sentencepiece.bpe.model": "sentencepiece_model",
        "merges.txt": "bpe_merges",
        "vocab.json": "bpe_vocab",
    }.items():
        path = os.path.join(model_path, file)
        if os.path.exists(path):
            kwargs[arg] = path

    if "user_dir" in kwargs:
        utils.import_user_module(argparse.Namespace(user_dir=kwargs["user_dir"]))

    models, args, task = checkpoint_utils.load_model_ensemble_and_task(
        [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(os.pathsep)],
        arg_overrides=kwargs,
    )

    return {
        "args": args,
        "task": task,
        "models": models,
    }
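
A minimal usage sketch for the helper above, assuming the directory holds model.pt together with its dictionaries and BPE files; the path and the tokenizer/bpe overrides are illustrative, not part of the original snippet.

# Hypothetical local archive containing model.pt plus dict/BPE files.
loaded = from_pretrained(
    "/path/to/model_dir",
    checkpoint_file="model.pt",
    tokenizer="moses",   # forwarded to arg_overrides via **kwargs
    bpe="fastbpe",
)
models = loaded["models"]  # list with one FairseqModel per checkpoint
task = loaded["task"]      # task holding the source/target dictionaries
args = loaded["args"]      # checkpoint args after the overrides were applied
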
Example 2
    @classmethod
    def from_pretrained(cls, parser, *inputs, model_name_or_path,
                        data_name_or_path, **kwargs):
        """
        Instantiate a FairseqModel from a pre-trained model file or PyTorch state dict.
        Downloads and caches the pre-trained model file if needed.

        Params:
            model_name_or_path: either
                - a str with the name of a pre-trained model to load
                - a path or URL to a pre-trained model state dict
        """
        import os

        from fairseq import checkpoint_utils, file_utils, options, tasks

        model_path = file_utils.load_archive_file(model_name_or_path)
        data_path = file_utils.load_archive_file(data_name_or_path)
        checkpoint_path = os.path.join(model_path, 'model.pt')

        # set data and parse
        model_args = options.parse_args_and_arch(parser,
                                                 input_args=[data_path])

        # override any kwargs passed in
        if kwargs:
            for arg_name, arg_val in kwargs.items():
                setattr(model_args, arg_name, arg_val)

        print(model_args)

        task = tasks.setup_task(model_args)
        print("loading model checkpoint from {}".format(checkpoint_path))

        model, _model_args = checkpoint_utils.load_model_ensemble(
            [checkpoint_path], task=task)

        return model[0]
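
A hedged sketch of how this classmethod variant might be invoked: it expects an argument parser, for example a generation parser built with fairseq.options; MyModel and both paths are placeholders, not names from the original snippet.

from fairseq import options

parser = options.get_generation_parser(interactive=True)
# MyModel stands in for a concrete FairseqModel subclass exposing this classmethod.
model = MyModel.from_pretrained(
    parser,
    model_name_or_path="/path/to/model_dir",
    data_name_or_path="/path/to/data_dir",
    beam=5,   # extra kwargs are applied onto the parsed args via setattr
)
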
Example 3
import argparse
import os


def from_pretrained(model_name_or_path,
                    checkpoint_file='model.pt',
                    data_name_or_path='.',
                    archive_map=None,
                    **kwargs):
    from fairseq import checkpoint_utils, file_utils, utils

    if archive_map is not None:
        if model_name_or_path in archive_map:
            model_name_or_path = archive_map[model_name_or_path]
        if data_name_or_path is not None and data_name_or_path in archive_map:
            data_name_or_path = archive_map[data_name_or_path]

        # allow archive_map to set default arg_overrides (e.g., tokenizer, bpe)
        # for each model
        if isinstance(model_name_or_path, dict):
            for k, v in model_name_or_path.items():
                if k == 'checkpoint_file':
                    checkpoint_file = v
                elif (k != 'path'
                      # only set kwargs that don't already have overrides
                      and k not in kwargs):
                    kwargs[k] = v
            model_name_or_path = model_name_or_path['path']

    model_path = file_utils.load_archive_file(model_name_or_path)

    # convenience hack for loading data and BPE codes from model archive
    if data_name_or_path.startswith('.'):
        kwargs['data'] = os.path.abspath(
            os.path.join(model_path, data_name_or_path))
    else:
        kwargs['data'] = file_utils.load_archive_file(data_name_or_path)
    for file, arg in {
            'code': 'bpe_codes',
            'bpecodes': 'bpe_codes',
            'sentencepiece.bpe.model': 'sentencepiece_vocab',
    }.items():
        print(model_path, file)
        path = os.path.join(model_path, file)
        if os.path.exists(path):
            kwargs[arg] = path

    if 'user_dir' in kwargs:
        utils.import_user_module(
            argparse.Namespace(user_dir=kwargs['user_dir']))

    models, args, task = checkpoint_utils.load_model_ensemble_and_task(
        [
            os.path.join(model_path, cpt)
            for cpt in checkpoint_file.split(os.pathsep)
        ],
        arg_overrides=kwargs,
    )

    return {
        'args': args,
        'task': task,
        'models': models,
    }
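
Because this variant splits checkpoint_file on os.pathsep, several checkpoints from the same archive can be ensembled; a sketch, with placeholder paths:

import os

loaded = from_pretrained(
    '/path/to/model_dir',
    # joined with the platform path separator (':' on Unix), matching the split above
    checkpoint_file=os.pathsep.join(['model1.pt', 'model2.pt']),
)
print(len(loaded['models']))  # -> 2, one model per checkpoint in the ensemble
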
Example 4
    @classmethod
    def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path=None, **kwargs):
        """
        Load a :class:`~fairseq.models.FairseqModel` from a pre-trained model
        file. Downloads and caches the pre-trained model file if needed.

        The base implementation returns a :class:`fairseq.hub_utils.Generator`,
        which can be used to generate translations or sample from language
        models. The underlying :class:`~fairseq.models.FairseqModel` can be
        accessed via the *generator.models* attribute.

        Other models may override this to implement custom PyTorch Hub APIs.

        Args:
            model_name_or_path (str): either the name of a pre-trained model to
                load or a path/URL to a pre-trained model state dict
            checkpoint_file (str, optional): colon-separated list of checkpoint
                files in the model archive to ensemble (default: 'model.pt')
            data_name_or_path (str, optional): point args.data to the archive
                at the given path/URL. Can start with '.' or './' to reuse the
                model archive path.
        """
        import os

        from fairseq import checkpoint_utils, file_utils, hub_utils

        if hasattr(cls, 'hub_models'):
            archive_map = cls.hub_models()
            if model_name_or_path in archive_map:
                model_name_or_path = archive_map[model_name_or_path]
            if data_name_or_path is not None and data_name_or_path in archive_map:
                data_name_or_path = archive_map[data_name_or_path]

        model_path = file_utils.load_archive_file(model_name_or_path)

        # convenience hack for loading data and BPE codes from model archive
        if data_name_or_path is not None:
            if data_name_or_path.startswith('.'):
                kwargs['data'] = os.path.abspath(os.path.join(model_path, data_name_or_path))
            else:
                kwargs['data'] = file_utils.load_archive_file(data_name_or_path)
        for file, arg in {
            'code': 'bpe_codes',
            'bpecodes': 'bpe_codes',
            'sentencepiece.bpe.model': 'sentencepiece_vocab',
        }.items():
            path = os.path.join(model_path, file)
            if os.path.exists(path):
                kwargs[arg] = path

        models, args, task = checkpoint_utils.load_model_ensemble_and_task(
            [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(':')],
            arg_overrides=kwargs,
        )

        print(args)

        return hub_utils.Generator(args, task, models)
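
A usage sketch for this hub-style entry point; 'transformer.wmt19.en-de' is assumed to be a key in cls.hub_models(), and the tokenizer/bpe kwargs are passed through as arg_overrides.

generator = TransformerModel.from_pretrained(
    'transformer.wmt19.en-de',
    checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt',  # ':'-separated ensemble
    tokenizer='moses',
    bpe='fastbpe',
)
ensemble = generator.models  # the underlying FairseqModel instances (see docstring above)
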
Example 5
    @classmethod
    def from_pretrained(cls, parser, *args, model_name_or_path, data_name_or_path,
                        checkpoint_file='model.pt', extra_task_args=None, **kwargs):
        import os

        from fairseq import checkpoint_utils, file_utils, options, tasks, utils

        model_path = file_utils.load_archive_file(model_name_or_path)
        data_path = file_utils.load_archive_file(data_name_or_path)
        checkpoint_path = os.path.join(model_path, checkpoint_file)

        task_name = kwargs.get('task', 'translation')

        # set data and parse
        model_args = options.parse_args_and_arch(
            parser,
            input_args=[data_path, '--task', task_name] + (extra_task_args or [])
        )

        # override any kwargs passed in
        if kwargs:
            for arg_name, arg_val in kwargs.items():
                setattr(model_args, arg_name, arg_val)

        # make any --user-dir modules importable before the task is set up
        utils.import_user_module(model_args)

        if model_args.buffer_size < 1:
            model_args.buffer_size = 1
        if model_args.max_tokens is None and model_args.max_sentences is None:
            model_args.max_sentences = 1

        assert not model_args.sampling or model_args.nbest == model_args.beam, \
            '--sampling requires --nbest to be equal to --beam'
        assert not model_args.max_sentences or model_args.max_sentences <= model_args.buffer_size, \
            '--max-sentences/--batch-size cannot be larger than --buffer-size'

        print(model_args)

        task = tasks.setup_task(model_args)
        print("loading model checkpoint from {}".format(checkpoint_path))

        model, _model_args = checkpoint_utils.load_model_ensemble(
            [checkpoint_path],
            task=task,
            arg_overrides=kwargs,
        )

        src_bpe = None
        for bpe in ['bpecodes', 'vocab.bpe', 'sentencepiece.bpe.model']:
            path = os.path.join(model_path, bpe)
            if os.path.exists(path):
                src_bpe = path
                break

        return cls(task, model, model_args, src_bpe, kwargs.get('remove_bpe', '@@ '))
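
An illustrative call for this interactive variant, assuming a generation parser built with interactive=True so that buffer_size and the batching flags exist; the wrapper class name and the paths are placeholders.

from fairseq import options

parser = options.get_generation_parser(interactive=True)
hub = MyInteractiveHub.from_pretrained(
    parser,
    model_name_or_path='/path/to/model_dir',
    data_name_or_path='/path/to/data_dir',
    checkpoint_file='checkpoint_best.pt',
    extra_task_args=['--source-lang', 'en', '--target-lang', 'de'],
    remove_bpe='@@ ',   # picked up by kwargs.get('remove_bpe', ...) above
)
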
Example 6
import argparse
import os


def from_pretrained(
    model_name_or_path,
    checkpoint_file='model.pt',
    data_name_or_path='.',
    archive_map=None,
    **kwargs
):
    from fairseq import checkpoint_utils, file_utils, utils

    if archive_map is not None:
        if model_name_or_path in archive_map:
            model_name_or_path = archive_map[model_name_or_path]
        if data_name_or_path is not None and data_name_or_path in archive_map:
            data_name_or_path = archive_map[data_name_or_path]

    model_path = file_utils.load_archive_file(model_name_or_path)

    # convenience hack for loading data and BPE codes from model archive
    if data_name_or_path.startswith('.'):
        kwargs['data'] = os.path.abspath(os.path.join(model_path, data_name_or_path))
    else:
        kwargs['data'] = file_utils.load_archive_file(data_name_or_path)
    for file, arg in {
        'code': 'bpe_codes',
        'bpecodes': 'bpe_codes',
        'sentencepiece.bpe.model': 'sentencepiece_vocab',
    }.items():
        path = os.path.join(model_path, file)
        if os.path.exists(path):
            kwargs[arg] = path

    if 'user_dir' in kwargs:
        utils.import_user_module(argparse.Namespace(user_dir=kwargs['user_dir']))

    models, args, task = checkpoint_utils.load_model_ensemble_and_task(
        [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(':')],
        arg_overrides=kwargs,
    )

    return {
        'args': args,
        'task': task,
        'models': models,
    }
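
Finally, a sketch of the user_dir hook in this version: passing user_dir as a kwarg makes the helper import a custom module directory before the checkpoints are loaded. Both paths are placeholders.

loaded = from_pretrained(
    '/path/to/model_dir',
    checkpoint_file='model.pt',
    user_dir='/path/to/my_fairseq_extensions',  # hypothetical --user-dir plugin directory
)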