Example No. 1
    def save_pretrained(self, save_directory: Union[str, os.PathLike]):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        :func:`~transformers.FlaxPreTrainedModel.from_pretrained` class method.

        Arguments:
            save_directory (:obj:`str` or :obj:`os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
        """
        if os.path.isfile(save_directory):
            logger.error(
                "Provided path ({}) should be a directory, not a file".format(
                    save_directory))
            return
        os.makedirs(save_directory, exist_ok=True)

        # get abs dir
        save_directory = os.path.abspath(save_directory)
        # save config as well
        self.config.save_pretrained(save_directory)

        # save model
        with open(os.path.join(save_directory, FLAX_WEIGHTS_NAME), "wb") as f:
            model_bytes = to_bytes(self.params)
            f.write(model_bytes)
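A minimal usage sketch for the method above, assuming a Flax model class such as FlaxBertModel from transformers is available (downloading the checkpoint requires network access or a local cache); the directory name is illustrative:

from transformers import FlaxBertModel

# Download (or load from cache) a pretrained Flax model, then save it locally.
model = FlaxBertModel.from_pretrained("bert-base-cased")
model.save_pretrained("./my_flax_model")  # writes config.json and the Flax weights file

# The saved directory can be passed back to from_pretrained.
reloaded = FlaxBertModel.from_pretrained("./my_flax_model")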
Example No. 2
    def save_vocabulary(self,
                        save_directory: str,
                        filename_prefix: Optional[str] = None) -> Tuple[str]:
        if not os.path.isdir(save_directory):
            logger.error("Vocabulary path ({}) should be a directory".format(
                save_directory))
            return
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") +
            VOCAB_FILES_NAMES["vocab_file"])
        merge_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") +
            VOCAB_FILES_NAMES["merges_file"])

        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, ensure_ascii=False))

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(),
                                                  key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        "Saving vocabulary to {}: BPE merge indices are not consecutive."
                        " Please check that the tokenizer is not corrupted!".
                        format(merge_file))
                    index = token_index
                writer.write(" ".join(bpe_tokens) + "\n")
                index += 1

        return vocab_file, merge_file
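A short sketch of how the two files written above could be read back, assuming the usual layout behind VOCAB_FILES_NAMES (a JSON token-to-id map plus a merges file with one space-separated pair per line); the file names below are illustrative:

import json

# Read the JSON vocabulary (token -> id) written by save_vocabulary.
with open("vocab.json", encoding="utf-8") as f:
    encoder = json.load(f)

# Read the BPE merges: one "tokenA tokenB" pair per line, rank given by line order.
bpe_ranks = {}
with open("merges.txt", encoding="utf-8") as f:
    for rank, line in enumerate(f):
        pair = tuple(line.rstrip("\n").split(" "))
        if len(pair) == 2:
            bpe_ranks[pair] = rank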
Example No. 3
    def __init__(self,
                 arch,
                 model_dir,
                 src_vocab_path=None,
                 trg_vocab_path=None,
                 embed_size=50,
                 hidden_size=50,
                 dropout=0.5,
                 max_length=128):
        logger.debug("device: {}".format(device))
        if arch in ['seq2seq', 'convseq2seq']:
            self.src_2_ids = load_word_dict(src_vocab_path)
            self.trg_2_ids = load_word_dict(trg_vocab_path)
            self.id_2_trgs = {v: k for k, v in self.trg_2_ids.items()}
            if arch == 'seq2seq':
                logger.debug('use seq2seq model.')
                self.model = Seq2Seq(encoder_vocab_size=len(self.src_2_ids),
                                     decoder_vocab_size=len(self.trg_2_ids),
                                     embed_size=embed_size,
                                     enc_hidden_size=hidden_size,
                                     dec_hidden_size=hidden_size,
                                     dropout=dropout).to(device)
                model_path = os.path.join(model_dir, 'seq2seq.pth')
                self.model.load_state_dict(torch.load(model_path))
                self.model.eval()
            else:
                logger.debug('use convseq2seq model.')
                trg_pad_idx = self.trg_2_ids[PAD_TOKEN]
                self.model = ConvSeq2Seq(
                    encoder_vocab_size=len(self.src_2_ids),
                    decoder_vocab_size=len(self.trg_2_ids),
                    embed_size=embed_size,
                    enc_hidden_size=hidden_size,
                    dec_hidden_size=hidden_size,
                    dropout=dropout,
                    trg_pad_idx=trg_pad_idx,
                    device=device,
                    max_length=max_length).to(device)
                model_path = os.path.join(model_dir, 'convseq2seq.pth')
                self.model.load_state_dict(torch.load(model_path))
                self.model.eval()
        elif arch == 'bertseq2seq':
            # Bert Seq2seq model
            logger.debug('use bert seq2seq model.')
            use_cuda = torch.cuda.is_available()

            # encoder_type=None, encoder_name=None, decoder_name=None
            self.model = Seq2SeqModel("bert",
                                      "{}/encoder".format(model_dir),
                                      "{}/decoder".format(model_dir),
                                      use_cuda=use_cuda)
        else:
            logger.error('invalid arch: {}'.format(arch))
            raise ValueError(
                "Invalid model arch, must be one of: seq2seq, convseq2seq, bertseq2seq.")
        self.arch = arch
        self.max_length = max_length
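The constructor above is shown without its class name; assuming it belonged to an inference wrapper called, say, Inference (a hypothetical name), it could be instantiated like this, with all paths illustrative:

# Hypothetical instantiation of the class whose __init__ is shown above.
corrector = Inference('seq2seq',
                      model_dir='output/',
                      src_vocab_path='output/vocab_src.txt',
                      trg_vocab_path='output/vocab_trg.txt',
                      embed_size=50,
                      hidden_size=50,
                      max_length=128)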
Example No. 4
def load_word_dict(save_path):
    dict_data = dict()
    with open(save_path, 'r', encoding='utf-8') as f:
        for line in f:
            items = line.strip().split()
            try:
                dict_data[items[0]] = int(items[1])
            except (IndexError, ValueError):
                logger.error('error parsing vocab line: {}'.format(line))
    return dict_data
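The parser above expects one whitespace-separated "token index" pair per line. A small round-trip sketch, assuming load_word_dict is importable and using a made-up file name:

# Write a tiny vocabulary in the format load_word_dict expects, then read it back.
sample = {"<pad>": 0, "<unk>": 1, "hello": 2}
with open("vocab.txt", "w", encoding="utf-8") as f:
    for token, idx in sample.items():
        f.write("{} {}\n".format(token, idx))

word_dict = load_word_dict("vocab.txt")
assert word_dict == sample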
Example No. 5
    def mask_token(self) -> str:
        """
        :obj:`str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while
        not having been set.

        The RoBERTa tokenizer has a special mask token to be usable in the fill-mask pipeline. The mask token will
        greedily include the space before the `<mask>`.
        """
        if self._mask_token is None and self.verbose:
            logger.error("Using mask_token, but it is not set yet.")
            return None
        return str(self._mask_token)
Example No. 6
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        if not os.path.isdir(save_directory):
            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)
Example No. 7
 def __init__(self,
              arch,
              model_path,
              src_vocab_path,
              trg_vocab_path,
              embed_size=50,
              hidden_size=50,
              dropout=0.5,
              max_length=128):
     self.src_2_ids = load_word_dict(src_vocab_path)
     self.trg_2_ids = load_word_dict(trg_vocab_path)
     self.id_2_trgs = {v: k for k, v in self.trg_2_ids.items()}
     if arch == 'seq2seq':
         self.model = Seq2Seq(encoder_vocab_size=len(self.src_2_ids),
                              decoder_vocab_size=len(self.trg_2_ids),
                              embed_size=embed_size,
                              enc_hidden_size=hidden_size,
                              dec_hidden_size=hidden_size,
                              dropout=dropout).to(device)
         self.model.load_state_dict(torch.load(model_path))
         self.model.eval()
     elif arch == 'convseq2seq':
         trg_pad_idx = self.trg_2_ids[PAD_TOKEN]
         self.model = ConvSeq2Seq(encoder_vocab_size=len(self.src_2_ids),
                                  decoder_vocab_size=len(self.trg_2_ids),
                                  embed_size=embed_size,
                                  enc_hidden_size=hidden_size,
                                  dec_hidden_size=hidden_size,
                                  dropout=dropout,
                                  trg_pad_idx=trg_pad_idx,
                                  device=device,
                                  max_length=max_length).to(device)
         self.model.load_state_dict(torch.load(model_path))
         self.model.eval()
     elif arch == 'bertseq2seq':
         # Bert Seq2seq model
          use_cuda = torch.cuda.is_available()
         # encoder_type=None, encoder_name=None, decoder_name=None
         self.model = Seq2SeqModel("bert",
                                   "output/bertseq2seq/encoder",
                                   "output/bertseq2seq/decoder",
                                   use_cuda=use_cuda)
         print(self.model)
     else:
          logger.error('invalid arch: {}'.format(arch))
          raise ValueError(
              "Invalid model arch, must be one of: seq2seq, convseq2seq, bertseq2seq.")
     self.arch = arch
     self.max_length = max_length
Example No. 8
    def __init__(self, **kwargs):
        # Recommended attributes from https://arxiv.org/abs/1810.03993 (see papers)
        self.model_details = kwargs.pop("model_details", {})
        self.intended_use = kwargs.pop("intended_use", {})
        self.factors = kwargs.pop("factors", {})
        self.metrics = kwargs.pop("metrics", {})
        self.evaluation_data = kwargs.pop("evaluation_data", {})
        self.training_data = kwargs.pop("training_data", {})
        self.quantitative_analyses = kwargs.pop("quantitative_analyses", {})
        self.ethical_considerations = kwargs.pop("ethical_considerations", {})
        self.caveats_and_recommendations = kwargs.pop(
            "caveats_and_recommendations", {})

        # Open additional attributes
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                logger.error("Can't set {} with value {} for {}".format(
                    key, value, self))
                raise err
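A sketch of how this kwargs-popping pattern behaves, using a hypothetical stand-in class (the real class name is not shown in the snippet); any key outside the recommended set simply becomes an attribute:

class CardLike:
    # Hypothetical stand-in that mirrors the __init__ pattern above.
    def __init__(self, **kwargs):
        self.model_details = kwargs.pop("model_details", {})
        self.intended_use = kwargs.pop("intended_use", {})
        # ... the remaining recommended sections are popped the same way ...
        for key, value in kwargs.items():
            setattr(self, key, value)  # open-ended extra attributes

card = CardLike(model_details={"name": "demo"}, license="mit")
print(card.model_details)  # {'name': 'demo'}
print(card.license)        # 'mit'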
Example No. 9
    def get_config_dict(cls, pretrained_model_name_or_path: Union[str,
                                                                  os.PathLike],
                        **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for instantiating a
        :class:`~transformers.PretrainedConfig` using ``from_dict``.

        Parameters:
            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.

        Returns:
            :obj:`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.

        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        use_auth_token = kwargs.pop("use_auth_token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        if os.path.isdir(pretrained_model_name_or_path):
            config_file = os.path.join(pretrained_model_name_or_path,
                                       CONFIG_NAME)
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
                pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        else:
            config_file = hf_bucket_url(pretrained_model_name_or_path,
                                        filename=CONFIG_NAME,
                                        revision=revision,
                                        mirror=None)

        try:
            # Load from URL or cache if already cached
            resolved_config_file = cached_path(
                config_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
            )
            # Load config dict
            config_dict = cls._dict_from_json_file(resolved_config_file)

        except EnvironmentError as err:
            logger.error(err)
            msg = (
                f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
            )
            raise EnvironmentError(msg)

        except json.JSONDecodeError:
            msg = (
                "Couldn't reach server at '{}' to download configuration file or "
                "configuration file is not a valid JSON file. "
                "Please check network or file content here: {}.".format(
                    config_file, resolved_config_file))
            raise EnvironmentError(msg)

        if resolved_config_file == config_file:
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info(
                "loading configuration file {} from cache at {}".format(
                    config_file, resolved_config_file))

        return config_dict, kwargs
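A hedged usage sketch: in transformers this classmethod is reached through PretrainedConfig subclasses, so BertConfig is used below only as an example; resolving a hub model id requires network access or a local cache:

from transformers import BertConfig

# Resolve the raw configuration dictionary plus any kwargs that were not consumed.
config_dict, unused_kwargs = BertConfig.get_config_dict("bert-base-uncased")
print(config_dict.get("hidden_size"))  # e.g. 768 for the base model
print(unused_kwargs)                   # {}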
Example No. 10
    def __init__(self, **kwargs):
        # Attributes with defaults
        self.return_dict = kwargs.pop("return_dict", True)
        self.output_hidden_states = kwargs.pop("output_hidden_states", False)
        self.output_attentions = kwargs.pop("output_attentions", False)
        self.torchscript = kwargs.pop("torchscript",
                                      False)  # Only used by PyTorch models
        self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
        self.pruned_heads = kwargs.pop("pruned_heads", {})
        self.tie_word_embeddings = kwargs.pop(
            "tie_word_embeddings", True
        )  # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.

        # `is_decoder` is used in encoder-decoder models to differentiate the encoder from the decoder
        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
        self.is_decoder = kwargs.pop("is_decoder", False)
        self.add_cross_attention = kwargs.pop("add_cross_attention", False)
        self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)

        # Parameters for sequence generation
        self.max_length = kwargs.pop("max_length", 20)
        self.min_length = kwargs.pop("min_length", 0)
        self.do_sample = kwargs.pop("do_sample", False)
        self.early_stopping = kwargs.pop("early_stopping", False)
        self.num_beams = kwargs.pop("num_beams", 1)
        self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
        self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
        self.temperature = kwargs.pop("temperature", 1.0)
        self.top_k = kwargs.pop("top_k", 50)
        self.top_p = kwargs.pop("top_p", 1.0)
        self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
        self.length_penalty = kwargs.pop("length_penalty", 1.0)
        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
        self.bad_words_ids = kwargs.pop("bad_words_ids", None)
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
        self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)

        # Fine-tuning task arguments
        self.architectures = kwargs.pop("architectures", None)
        self.finetuning_task = kwargs.pop("finetuning_task", None)
        self.id2label = kwargs.pop("id2label", None)
        self.label2id = kwargs.pop("label2id", None)
        if self.id2label is not None:
            kwargs.pop("num_labels", None)
            self.id2label = dict(
                (int(key), value) for key, value in self.id2label.items())
            # Keys are always strings in JSON so convert ids to int here.
        else:
            self.num_labels = kwargs.pop("num_labels", 2)

        # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
        self.tokenizer_class = kwargs.pop("tokenizer_class", None)
        self.prefix = kwargs.pop("prefix", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)
        self.sep_token_id = kwargs.pop("sep_token_id", None)

        self.decoder_start_token_id = kwargs.pop("decoder_start_token_id",
                                                 None)

        # task specific arguments
        self.task_specific_params = kwargs.pop("task_specific_params", None)

        # TPU arguments
        self.xla_device = kwargs.pop("xla_device", None)

        # Name or path to the pretrained checkpoint
        self._name_or_path = str(kwargs.pop("name_or_path", ""))

        # Additional attributes without default values
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                logger.error("Can't set {} with value {} for {}".format(
                    key, value, self))
                raise err
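A small sketch of why the id2label keys are cast to int above: JSON object keys are always strings, so a config parsed from config.json would otherwise index labels with "0" instead of 0 (values are illustrative):

import json

raw = json.loads('{"id2label": {"0": "NEGATIVE", "1": "POSITIVE"}}')
print(list(raw["id2label"].keys()))  # ['0', '1'] -- still strings after JSON parsing

id2label = {int(key): value for key, value in raw["id2label"].items()}
print(id2label[1])  # 'POSITIVE'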
Example No. 11
    def from_pretrained(cls,
                        pretrained_model_name_or_path: Union[str, os.PathLike],
                        dtype: jnp.dtype = jnp.float32,
                        *model_args,
                        **kwargs):
        r"""
        Instantiate a pretrained flax model from a pre-trained model configuration.

        The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                Can be either:

                    - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
                      a user or organization name, like ``dbmdz/bert-base-german-cased``.
                    - A path to a `directory` containing model weights saved using
                      :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                    - A path or url to a `pt index checkpoint file` (e.g., ``./tf_model/model.ckpt.index``). In this
                      case, ``from_pt`` should be set to :obj:`True`.
            model_args (sequence of positional arguments, `optional`):
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
            config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`):
                Can be either:

                    - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
                    - a string or path valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.

                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
                be automatically loaded when:

                    - The model is a model provided by the library (loaded with the `model id` string of a pretrained
                      model).
                    - The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
                      by supplying the save directory.
                    - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
                      configuration JSON file named `config.json` is found in the directory.
            cache_dir (:obj:`Union[str, os.PathLike]`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Load the model weights from a PyTorch checkpoint save file (see docstring of
                ``pretrained_model_name_or_path`` argument).
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
                identifier allowed by git.
            kwargs (remaining dictionary of keyword arguments, `optional`):
                Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g.,
                :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
                automatically loaded:

                    - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
                      underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
                      already been done)
                    - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
                      initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
                      ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
                      with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
                      attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            >>> from transformers import BertConfig, FlaxBertModel
            >>> # Download model and configuration from huggingface.co and cache.
            >>> model = FlaxBertModel.from_pretrained('bert-base-cased')
            >>> # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
            >>> model = FlaxBertModel.from_pretrained('./test/saved_model/')
            >>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower, for example purposes, not runnable).
            >>> config = BertConfig.from_json_file('./pt_model/config.json')
            >>> model = FlaxBertModel.from_pretrained('./pt_model/pytorch_model.bin', from_pt=True, config=config)
        """
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Add the dtype to model_kwargs
        model_kwargs["dtype"] = dtype

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if from_pt and os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                WEIGHTS_NAME)
                elif os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     FLAX_WEIGHTS_NAME)):
                    # Load from a Flax checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                FLAX_WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "Error no file named {} found in directory {} or `from_pt` set to False"
                        .format(
                            [FLAX_WEIGHTS_NAME, WEIGHTS_NAME],
                            pretrained_model_name_or_path,
                        ))
            elif os.path.isfile(
                    pretrained_model_name_or_path) or is_remote_url(
                        pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME,
                    revision=revision,
                )

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                )
            except EnvironmentError as err:
                logger.error(err)
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
                )
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
                logger.info(f"loading weights file {archive_file}")
            else:
                logger.info(
                    f"loading weights file {archive_file} from cache at {resolved_archive_file}"
                )
        else:
            resolved_archive_file = None

        # Instantiate model.
        with open(resolved_archive_file, "rb") as state_f:
            try:
                if from_pt:
                    import torch

                    state = torch.load(state_f)

                    state = convert_state_dict_from_pt(cls, state, config)
                else:
                    state = from_bytes(cls, state_f.read())
            except UnpicklingError:
                raise EnvironmentError(
                    f"Unable to convert pytorch model {archive_file} to Flax deserializable object. "
                )

        # init random models
        model = cls(config, *model_args, **model_kwargs)

        # if model is base model only use model_prefix key
        if cls.base_model_prefix not in dict(
                model.params) and cls.base_model_prefix in state:
            state = state[cls.base_model_prefix]

        # flatten dicts
        state = flatten_dict(state)
        random_state = flatten_dict(unfreeze(model.params))

        missing_keys = model.required_params - set(state.keys())
        unexpected_keys = set(state.keys()) - model.required_params

        # add missing keys as random parameters
        for missing_key in missing_keys:
            state[missing_key] = random_state[missing_key]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.info(
                f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
            )

        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        else:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )

        # set correct parameters
        model.params = unflatten_dict(state)
        return model
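A minimal sketch of the missing/unexpected-key bookkeeping near the end of the method, using plain dictionaries instead of Flax parameter trees; all names and values are illustrative:

# What the freshly initialized model expects vs. what the checkpoint provides.
random_state = {("encoder", "kernel"): "randomly_initialized",
                ("classifier", "kernel"): "randomly_initialized"}
state = {("encoder", "kernel"): "from_checkpoint",
         ("pooler", "kernel"): "from_checkpoint"}
required_params = set(random_state.keys())

missing_keys = required_params - set(state.keys())     # {('classifier', 'kernel')}
unexpected_keys = set(state.keys()) - required_params   # {('pooler', 'kernel')}

# Missing parameters keep their random initialization; unexpected ones only trigger a warning.
for key in missing_keys:
    state[key] = random_state[key]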
Example No. 12
    def _tokenize(self, text, lang="en", bypass_tokenizer=False):
        """
        Tokenize a string given a language code. For Chinese, Japanese and Thai, we use a language-specific
        tokenizer. Otherwise, we use Moses.

        Details of tokenization:

            - [sacremoses](https://github.com/alvations/sacremoses): port of Moses
            - Install with `pip install sacremoses`
            - [pythainlp](https://github.com/PyThaiNLP/pythainlp): Thai tokenizer
            - Install with `pip install pythainlp`
            - [kytea](https://github.com/chezou/Mykytea-python): Japanese tokenizer, wrapper of
              [KyTea](https://github.com/neubig/kytea)
            - Install with the following steps:

            ::

                git clone git@github.com:neubig/kytea.git && cd kytea
                autoreconf -i
                ./configure --prefix=$HOME/local
                make && make install
                pip install kytea

            - [jieba](https://github.com/fxsjy/jieba): Chinese tokenizer (*)
            - Install with `pip install jieba`

        (*) The original XLM used [Stanford
        Segmenter](https://nlp.stanford.edu/software/stanford-segmenter-2018-10-16.zip). However, the wrapper
        (`nltk.tokenize.stanford_segmenter`) is slow due to JVM overhead, and it will be deprecated. Jieba is a lot
        faster and pip-installable. Note there is some mismatch with the Stanford Segmenter. It should be fine if you
        fine-tune the model with Chinese supervision. If you want the same exact behaviour, use the original XLM
        [preprocessing script](https://github.com/facebookresearch/XLM/tree/master/tools) to tokenize the sentence
        externally, and set `bypass_tokenizer=True` to bypass the tokenizer.

        Args:

            - lang: ISO language code (default = 'en') (string). Languages should belong to the set of languages
              supported by the model. However, we don't enforce it.
            - bypass_tokenizer: Allow users to preprocess and tokenize the sentences externally (default = False)
              (bool). If True, we only apply BPE.

        Returns:
            List of tokens.
        """
        if lang and self.lang2id and lang not in self.lang2id:
            logger.error(
                "Supplied language code not found in lang2id mapping. Please check that your language is supported by the loaded pretrained model."
            )
        if bypass_tokenizer:
            text = text.split()
        elif lang not in self.lang_with_custom_tokenizer:
            text = self.moses_pipeline(text, lang=lang)
            # TODO: make sure we are using `xlm-mlm-enro-1024`, since XLM-100 doesn't have this step
            if lang == "ro":
                text = romanian_preprocessing(text)
            text = self.moses_tokenize(text, lang=lang)
        elif lang == "th":
            text = self.moses_pipeline(text, lang=lang)
            try:
                if "pythainlp" not in sys.modules:
                    from pythainlp.tokenize import word_tokenize as th_word_tokenize
                else:
                    th_word_tokenize = sys.modules["pythainlp"].word_tokenize
            except (AttributeError, ImportError):
                logger.error(
                    "Make sure you install PyThaiNLP (https://github.com/PyThaiNLP/pythainlp) with the following steps"
                )
                logger.error("1. pip install pythainlp")
                raise
            text = th_word_tokenize(text)
        elif lang == "zh":
            try:
                if "jieba" not in sys.modules:
                    import jieba
                else:
                    jieba = sys.modules["jieba"]
            except (AttributeError, ImportError):
                logger.error(
                    "Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps"
                )
                logger.error("1. pip install jieba")
                raise
            text = " ".join(jieba.cut(text))
            text = self.moses_pipeline(text, lang=lang)
            text = text.split()
        elif lang == "ja":
            text = self.moses_pipeline(text, lang=lang)
            text = self.ja_tokenize(text)
        else:
            raise ValueError("It should not reach here")

        if self.do_lowercase_and_remove_accent and not bypass_tokenizer:
            text = lowercase_and_remove_accent(text)

        split_tokens = []
        for token in text:
            if token:
                split_tokens.extend([t for t in self.bpe(token).split(" ")])

        return split_tokens
Example No. 13
    def ja_tokenize(self, text):
        if self.ja_word_tokenizer is None:
            try:
                import Mykytea

                self.ja_word_tokenizer = Mykytea.Mykytea(
                    "-model %s/local/share/kytea/model.bin" %
                    os.path.expanduser("~"))
            except (AttributeError, ImportError):
                logger.error(
                    "Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper (https://github.com/chezou/Mykytea-python) with the following steps"
                )
                logger.error(
                    "1. git clone [email protected]:neubig/kytea.git && cd kytea")
                logger.error("2. autoreconf -i")
                logger.error("3. ./configure --prefix=$HOME/local")
                logger.error("4. make && make install")
                logger.error("5. pip install kytea")
                raise
        return list(self.ja_word_tokenizer.getWS(text))
Example No. 14
def load_tf_weights_in_bert_generation(model,
                                       tf_hub_path,
                                       model_class,
                                       is_encoder_named_decoder=False,
                                       is_encoder=False):
    try:
        import numpy as np
        import tensorflow.compat.v1 as tf

        import tensorflow_hub as hub
        import tensorflow_text  # noqa: F401

        tf.disable_eager_execution()
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_model = hub.Module(tf_hub_path)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        init.run()
        all_variables = tf_model.variable_map
        keep_track_variables = all_variables.copy()
        for key in list(all_variables.keys()):
            if "global" in key:
                logger.info(f"Skipping {key}...")
                continue
            if not is_encoder:
                model_pointer = getattr(model, model_class)
            else:
                model_pointer = model
            is_embedding = False
            logger.info(f"Trying to match {key}...")
            # remove start_string = "module/bert/"
            sub_layers = key.split("/")[2:]
            if is_encoder_named_decoder and sub_layers[0] == "encoder":
                logger.info(f"Skipping encoder layer {key} for decoder")
                continue
            if is_encoder and sub_layers[0] == "decoder":
                logger.info(f"Skipping decoder layer {key} for encoder")
                continue
            for i, sub_layer in enumerate(sub_layers):
                if sub_layer == "embeddings":
                    is_embedding = True
                elif sub_layer == "LayerNorm":
                    is_embedding = False
                if "layer" in sub_layer:
                    model_pointer = model_pointer.layer[int(
                        sub_layer.split("_")[-1])]
                elif sub_layer in ["kernel", "gamma"]:
                    model_pointer = model_pointer.weight
                elif sub_layer == "beta":
                    model_pointer = model_pointer.bias
                elif sub_layer == "encdec":
                    model_pointer = model_pointer.crossattention.self
                elif sub_layer == "encdec_output":
                    model_pointer = model_pointer.crossattention.output
                elif is_encoder_named_decoder and sub_layer == "decoder":
                    model_pointer = model_pointer.encoder
                else:
                    if sub_layer == "attention" and "encdec" in sub_layers[i +
                                                                           1]:
                        continue
                    try:
                        model_pointer = getattr(model_pointer, sub_layer)
                    except AttributeError:
                        logger.info(
                            f"Skipping to initialize {key} at {sub_layer}...")
                        raise AttributeError

            array = np.asarray(sess.run(all_variables[key]))
            if not is_embedding:
                logger.info(
                    "Transposing numpy weight of shape {} for {}".format(
                        array.shape, key))
                array = np.transpose(array)
            else:
                model_pointer = model_pointer.weight

            try:
                assert (
                    model_pointer.shape == array.shape
                ), f"Pointer shape {model_pointer.shape} and array shape {array.shape} mismatched"
            except AssertionError as e:
                e.args += (model_pointer.shape, array.shape)
                raise
            logger.info(f"Initialize PyTorch weight {key}")

            model_pointer.data = torch.from_numpy(array.astype(np.float32))
            keep_track_variables.pop(key, None)

        logger.info("Weights not copied to PyTorch model: {}".format(", ".join(
            keep_track_variables.keys())))
        return model
Example No. 15
def train(arch, train_path, batch_size, embed_size, hidden_size, dropout,
          epochs, src_vocab_path, trg_vocab_path, model_path, max_length,
          use_segment):
    arch = arch.lower()
    if arch in ['seq2seq', 'convseq2seq']:
        source_texts, target_texts = create_dataset(train_path, None)

        src_2_ids = read_vocab(source_texts)
        trg_2_ids = read_vocab(target_texts)
        save_word_dict(src_2_ids, src_vocab_path)
        save_word_dict(trg_2_ids, trg_vocab_path)
        src_2_ids = load_word_dict(src_vocab_path)
        trg_2_ids = load_word_dict(trg_vocab_path)

        id_2_srcs = {v: k for k, v in src_2_ids.items()}
        id_2_trgs = {v: k for k, v in trg_2_ids.items()}
        train_src, train_trg = one_hot(source_texts,
                                       target_texts,
                                       src_2_ids,
                                       trg_2_ids,
                                       sort_by_len=True)

        k = 0
        print('src:', ' '.join([id_2_srcs[i] for i in train_src[k]]))
        print('trg:', ' '.join([id_2_trgs[i] for i in train_trg[k]]))

        train_data = gen_examples(train_src, train_trg, batch_size, max_length)

        if arch == 'seq2seq':
            # Normal seq2seq
            model = Seq2Seq(encoder_vocab_size=len(src_2_ids),
                            decoder_vocab_size=len(trg_2_ids),
                            embed_size=embed_size,
                            enc_hidden_size=hidden_size,
                            dec_hidden_size=hidden_size,
                            dropout=dropout).to(device)
            print(model)
            loss_fn = LanguageModelCriterion().to(device)
            optimizer = torch.optim.Adam(model.parameters())

            train_seq2seq_model(model,
                                train_data,
                                device,
                                loss_fn,
                                optimizer,
                                model_path,
                                epochs=epochs)
        else:
            # Conv seq2seq model
            trg_pad_idx = trg_2_ids[PAD_TOKEN]
            model = ConvSeq2Seq(encoder_vocab_size=len(src_2_ids),
                                decoder_vocab_size=len(trg_2_ids),
                                embed_size=embed_size,
                                enc_hidden_size=hidden_size,
                                dec_hidden_size=hidden_size,
                                dropout=dropout,
                                trg_pad_idx=trg_pad_idx,
                                device=device,
                                max_length=max_length).to(device)
            print(model)
            loss_fn = nn.CrossEntropyLoss(ignore_index=trg_pad_idx)
            optimizer = torch.optim.Adam(model.parameters())

            train_convseq2seq_model(model,
                                    train_data,
                                    device,
                                    loss_fn,
                                    optimizer,
                                    model_path,
                                    epochs=epochs)
    elif arch == 'bertseq2seq':
        # Bert Seq2seq model
        model_args = {
            "reprocess_input_data": True,
            "overwrite_output_dir": True,
            "max_seq_length": max_length if max_length else 128,
            "train_batch_size": batch_size if batch_size else 8,
            "num_train_epochs": epochs if epochs else 10,
            "save_eval_checkpoints": False,
            "save_model_every_epoch": False,
            "silent": False,
            "evaluate_generated_text": True,
            "evaluate_during_training": False,
            "evaluate_during_training_verbose": False,
            "use_multiprocessing": False,
            "save_best_model": True,
            "max_length": max_length if max_length else
            128,  # The maximum length of the sequence to be generated.
            "output_dir": "./output/bertseq2seq/",
        }

        # encoder_type=None, encoder_name=None, decoder_name=None, encoder_decoder_type=None, encoder_decoder_name=None,
        use_cuda = torch.cuda.is_available()
        model = Seq2SeqModel("bert",
                             "bert-base-chinese",
                             "bert-base-chinese",
                             args=model_args,
                             use_cuda=use_cuda)

        print('start train bertseq2seq ...')
        data = load_bert_data(train_path, use_segment)
        train_data, dev_data = train_test_split(data,
                                                test_size=0.1,
                                                shuffle=True)

        train_df = pd.DataFrame(train_data,
                                columns=['input_text', 'target_text'])
        dev_df = pd.DataFrame(dev_data, columns=['input_text', 'target_text'])

        model.train_model(train_df, eval_data=dev_df)
    else:
        logger.error('invalid arch: {}'.format(arch))
        raise ValueError(
            "Invalid model arch, must be one of: seq2seq, convseq2seq, bertseq2seq.")