Пример #1
0
def assert_packed_msg_equal(b1, b2):
    """Assert that two msgpack-encoded byte strings decode to equal mappings.

    The key sets are compared first (as sorted lists), then each value is
    compared key by key in sorted key order.
    """
    first = srsly.msgpack_loads(b1)
    second = srsly.msgpack_loads(b2)
    ordered_keys = sorted(first.keys())
    assert ordered_keys == sorted(second.keys())
    for key in ordered_keys:
        assert first[key] == second[key]
Пример #2
0
def from_bytes(bytes_data, setters, exclude):
    """Decode msgpack bytes and pass selected fields to their setters.

    bytes_data (bytes): msgpack-encoded message keyed by field name.
    setters (dict): Maps field names to callables that receive the field data.
    exclude: Field names to skip. Only the part of a setter key before the
        first "." is checked, so excluding "meta" also skips "meta.json".
        Setters whose key is absent from the message are skipped too.
    RETURNS: The full decoded message.
    """
    msg = srsly.msgpack_loads(bytes_data)
    for field_name, apply_value in setters.items():
        # Split to support file names like meta.json
        base_name = field_name.split(".")[0]
        if base_name in exclude or field_name not in msg:
            continue
        apply_value(msg[field_name])
    return msg
Пример #3
0
 def from_bytes(self, bytes_data):
     """Load serialized weights, dims and params into this layer tree.

     The tree is walked breadth-first starting at ``self``. Each layer that
     wraps a non-thinc ``_model``, or that has a ``_mem``, consumes the next
     entry of the serialized ``weights`` list, in traversal order.

     bytes_data (bytes): msgpack data containing a ``b"weights"`` list.
     RETURNS: ``self``, to allow chaining.
     """
     data = srsly.msgpack_loads(bytes_data)
     weights = data[b"weights"]
     # `queue` is extended while it is being iterated, which turns this loop
     # into a breadth-first traversal over `self` and all nested `_layers`.
     queue = [self]
     # Index of the next entry in `weights` to consume.
     i = 0
     for layer in queue:
         # Hack to support saving/loading PyTorch models. TODO: Improve
         if hasattr(layer, "_model") and not isinstance(layer._model, Model):
             # Foreign wrapped model: it deserializes its own weights entry.
             layer.from_bytes(weights[i])
             i += 1
         elif hasattr(layer, "_mem"):
             if b"seed" in weights[i]:
                 layer.seed = weights[i][b"seed"]
             for dim, value in weights[i][b"dims"].items():
                 # Keys may be decoded as bytes; setattr needs a str name.
                 if isinstance(dim, bytes):
                     dim = dim.decode("utf8")
                 setattr(layer, dim, value)
             for param in weights[i][b"params"]:
                 name = param[b"name"]
                 if isinstance(name, bytes):
                     name = name.decode("utf8")
                 dest = getattr(layer, name)
                 # Presumably copies in place into the existing param array
                 # rather than rebinding the attribute -- confirm copy_array.
                 copy_array(dest, param[b"value"])
             i += 1
         if hasattr(layer, "_layers"):
             queue.extend(layer._layers)
     return self
Пример #4
0
    def from_bytes(self, bytes_data: bytes) -> "Model":
        """Restore this model from its serialized bytes form.

        The payload is msgpack, so ``msgpack.loads()`` on the same bytes yields
        the underlying dictionary. Every array found in it is converted with
        ``self.ops.asarray`` before the dict is passed to ``from_dict``.
        Serialization is meant to round-trip: loading and then serializing a
        model should produce the same bytes.
        """
        decoded = srsly.msgpack_loads(bytes_data)
        with_arrays = convert_recursive(is_xp_array, self.ops.asarray, decoded)
        return self.from_dict(with_arrays)
Пример #5
0
    def from_bytes(self, patterns_bytes, **kwargs):
        """Populate the entity ruler from msgpack-encoded patterns.

        patterns_bytes (bytes): msgpack data holding the patterns to add.
        **kwargs: Other config parameters, accepted for API consistency (unused).
        RETURNS (EntityRuler): The loaded entity ruler.

        DOCS: https://spacy.io/api/entityruler#from_bytes
        """
        self.add_patterns(srsly.msgpack_loads(patterns_bytes))
        return self
Пример #6
0
    def from_bytes(self, bytes_data: bytes, **kwargs) -> "Lookups":
        """Replace all tables with those decoded from a bytestring.

        bytes_data (bytes): msgpack mapping of table name to table contents.
        RETURNS (Lookups): The loaded Lookups.

        DOCS: https://spacy.io/api/lookups#from_bytes
        """
        # Existing tables are dropped before decoding starts.
        self._tables = {}
        decoded = srsly.msgpack_loads(bytes_data)
        self._tables.update(
            (table_name, Table(table_name, contents))
            for table_name, contents in decoded.items()
        )
        return self
Пример #7
0
    def from_bytes(self, patterns_bytes, **kwargs):
        """Load the entity ruler from a bytestring.

        The bytes are decoded with msgpack, and the result is handed to
        ``add_patterns``.

        patterns_bytes (bytes): The bytestring to load.
        **kwargs: Other config parameters, mostly for consistency (unused).
        RETURNS (EntityRuler): The loaded entity ruler.

        DOCS: https://spacy.io/api/entityruler#from_bytes
        """
        decoded_patterns = srsly.msgpack_loads(patterns_bytes)
        self.add_patterns(decoded_patterns)
        return self
Пример #8
0
 def can_from_bytes(self, bytes_data: bytes, *, strict: bool = True) -> bool:
     """Report whether ``bytes_data`` could be loaded into this model.

     Data that fails msgpack decoding with ``ValueError`` yields False.
     Otherwise the decision is delegated to ``can_from_dict``. When ``strict``
     is set, that check also returns False if loading would change an
     attribute the model has already loaded.
     """
     try:
         decoded = srsly.msgpack_loads(bytes_data)
     except ValueError:
         return False
     else:
         # Outside the try's protection: errors from can_from_dict propagate.
         return self.can_from_dict(decoded, strict=strict)
Пример #9
0
 def get_docs(self, vocab):
     """Recover Doc objects from the annotations, using the given vocab.

     Each stored string is looked up in ``vocab`` first (presumably to
     register it in the shared StringStore), so ORTH ids can be mapped back
     to words. Docs are yielded lazily, one per stored token array.
     """
     for text in self.strings:
         vocab[text]
     orth_column = self.attrs.index(ORTH)
     for idx, token_rows in enumerate(self.tokens):
         doc_spaces = self.spaces[idx]
         doc_words = [vocab.strings[orth_id] for orth_id in token_rows[:, orth_column]]
         doc = Doc(vocab, words=doc_words, spaces=doc_spaces).from_array(self.attrs, token_rows)
         if self.store_user_data:
             doc.user_data.update(srsly.msgpack_loads(self.user_data[idx]))
         yield doc
Пример #10
0
 def from_bytes(self, bytes_data):
     """Restore the wrapped PyTorch model's config and weights from bytes.

     ``bytes_data`` is msgpack holding a ``"config"`` entry, stored on
     ``self.cfg``, and a ``"state"`` entry read with ``torch.load``. Weights
     are mapped to the CPU when the current ops run on CPU, and to the current
     CUDA device otherwise. The model is then moved there. Returns ``self``.
     """
     current_ops = get_current_ops()
     payload = srsly.msgpack_loads(bytes_data)
     self.cfg = payload["config"]
     state_buffer = BytesIO(payload["state"])
     state_buffer.seek(0)
     # NOTE(review): torch.load unpickles its input; only load trusted bytes.
     if current_ops.device_type != "cpu":  # pragma: no cover
         target_device = "cuda:%d" % torch.cuda.current_device()
     else:
         target_device = "cpu"
     self._model.load_state_dict(torch.load(state_buffer, map_location=target_device))
     self._model.to(target_device)
     return self
Пример #11
0
    def from_bytes(self, bytes_data: bytes, exclude: Sequence[str] = tuple()):
        """Load a Sense2Vec object from a bytestring.

        The vectors are always replaced. Frequencies default to empty when
        absent, and any stored config entries are merged into ``self.cfg``.
        The string store is replaced only when the data contains one and
        "strings" is not excluded.

        bytes_data (bytes): The data to load.
        exclude (list): Names of serialization fields to exclude.
        RETURNS (Sense2Vec): The loaded object.
        """
        payload = srsly.msgpack_loads(bytes_data)
        self.vectors = Vectors().from_bytes(payload["vectors"])
        self.freqs = dict(payload.get("freqs", []))
        self.cfg.update(payload.get("cfg", {}))
        load_strings = "strings" not in exclude and "strings" in payload
        if load_strings:
            self.strings = StringStore().from_bytes(payload["strings"])
        return self
Пример #12
0
    def from_bytes(self, bytes_data: bytes) -> "Table":
        """Load a table from a bytestring.

        Restores the table's name and bloom filter, then replaces its contents
        with the stored ``"dict"`` entry (empty when missing).

        bytes_data (bytes): The data to load.
        RETURNS (Table): The loaded table.

        DOCS: https://spacy.io/api/lookups#table.from_bytes
        """
        state = srsly.msgpack_loads(bytes_data)
        entries = state.get("dict", {})
        self.name = state["name"]
        bloom_bytes = state["bloom"]
        self.bloom = BloomFilter().from_bytes(bloom_bytes)
        self.clear()
        self.update(entries)
        return self
Пример #13
0
 def from_bytes(self, string):
     """Deserialize the binder's annotations from a byte string.

     ``string`` is gzip-compressed msgpack holding ``attrs``, ``strings`` and
     the flattened ``lengths`` / ``spaces`` / ``tokens`` buffers. The flat
     arrays are split back into per-entry arrays using ``lengths``.

     RETURNS: ``self``, to allow chaining.
     """
     msg = srsly.msgpack_loads(gzip.decompress(string))
     self.attrs = msg["attrs"]
     self.strings = set(msg["strings"])
     # numpy.fromstring's binary mode is deprecated; numpy.frombuffer is the
     # supported replacement. frombuffer returns a read-only view of the
     # bytes, so the arrays kept on self are copied to stay writable, as they
     # were with fromstring.
     lengths = numpy.frombuffer(msg["lengths"], dtype="int32")
     flat_spaces = numpy.frombuffer(msg["spaces"], dtype=bool).copy()
     flat_tokens = numpy.frombuffer(msg["tokens"], dtype="uint64").copy()
     # One row per token, one column per attribute in self.attrs.
     shape = (flat_tokens.size // len(self.attrs), len(self.attrs))
     flat_tokens = flat_tokens.reshape(shape)
     flat_spaces = flat_spaces.reshape((flat_spaces.size, 1))
     ops = NumpyOps()
     self.tokens = ops.unflatten(flat_tokens, lengths)
     self.spaces = ops.unflatten(flat_spaces, lengths)
     for tokens in self.tokens:
         assert len(tokens.shape) == 2, tokens.shape
     return self
Пример #14
0
 def from_bytes(self, string):
     """Deserialize the binder's annotations from a byte string.

     ``string`` is gzip-compressed msgpack holding ``attrs``, ``strings`` and
     the flattened ``lengths`` / ``spaces`` / ``tokens`` buffers. The flat
     arrays are split back into per-entry arrays using ``lengths``.

     RETURNS: ``self``, to allow chaining.
     """
     msg = srsly.msgpack_loads(gzip.decompress(string))
     self.attrs = msg["attrs"]
     self.strings = set(msg["strings"])
     # numpy.fromstring's binary mode is deprecated; numpy.frombuffer is the
     # supported replacement. frombuffer returns a read-only view of the
     # bytes, so the arrays kept on self are copied to stay writable, as they
     # were with fromstring.
     lengths = numpy.frombuffer(msg["lengths"], dtype="int32")
     flat_spaces = numpy.frombuffer(msg["spaces"], dtype=bool).copy()
     flat_tokens = numpy.frombuffer(msg["tokens"], dtype="uint64").copy()
     # One row per token, one column per attribute in self.attrs.
     shape = (flat_tokens.size // len(self.attrs), len(self.attrs))
     flat_tokens = flat_tokens.reshape(shape)
     flat_spaces = flat_spaces.reshape((flat_spaces.size, 1))
     ops = NumpyOps()
     self.tokens = ops.unflatten(flat_tokens, lengths)
     self.spaces = ops.unflatten(flat_spaces, lengths)
     for tokens in self.tokens:
         assert len(tokens.shape) == 2, tokens.shape
     return self
Пример #15
0
    def from_bytes(self, bytes_data):
        """Rebuild the wrapped Hugging Face tokenizer and model from bytes.

        ``bytes_data`` is msgpack with ``"config"`` and ``"tokenizer"`` entries.
        If ``"config"`` is non-empty, the config and tokenizer files are written
        to a temporary directory and reloaded with ``AutoConfig`` /
        ``AutoTokenizer``. A fresh transformer is then built from the config,
        and its weights are restored from the torch-serialized ``"state"``
        entry. Otherwise, only the stored ``"_init_tokenizer_config"`` and
        ``"_init_transformer_config"`` are kept, with no tokenizer or model.

        RETURNS: ``self``, to allow chaining.
        """
        msg = srsly.msgpack_loads(bytes_data)
        config_dict = msg["config"]
        tok_dict = msg["tokenizer"]
        if config_dict:
            # from_pretrained is given file/directory paths here, so the saved
            # config and tokenizer files are materialized in a scratch dir.
            with make_tempdir() as temp_dir:
                config_file = temp_dir / "config.json"
                srsly.write_json(config_file, config_dict)
                config = AutoConfig.from_pretrained(config_file)
                # tok_dict maps file names to their raw bytes.
                for x, x_bytes in tok_dict.items():
                    Path(temp_dir / x).write_bytes(x_bytes)
                tokenizer = AutoTokenizer.from_pretrained(str(temp_dir.absolute()))
                vocab_file_contents = None
                if hasattr(tokenizer, "vocab_file"):
                    # Keep the raw vocab file bytes in memory; the temp dir is
                    # presumably deleted when this `with` block exits.
                    vocab_file_name = tokenizer.vocab_files_names["vocab_file"]
                    vocab_file_path = str((temp_dir / vocab_file_name).absolute())
                    with open(vocab_file_path, "rb") as fileh:
                        vocab_file_contents = fileh.read()

            transformer = AutoModel.from_config(config)
            self._hfmodel = HFObjects(
                tokenizer,
                transformer,
                vocab_file_contents,
                SimpleFrozenDict(),
                SimpleFrozenDict(),
            )
            self._model = transformer
            # NOTE(review): torch.load unpickles its input; only load trusted bytes.
            filelike = BytesIO(msg["state"])
            filelike.seek(0)
            # Map the weights onto the device used by the current thinc ops.
            ops = get_current_ops()
            if ops.device_type == "cpu":
                map_location = "cpu"
            else:  # pragma: no cover
                device_id = torch.cuda.current_device()
                map_location = f"cuda:{device_id}"
            self._model.load_state_dict(torch.load(filelike, map_location=map_location))
            self._model.to(map_location)
        else:
            # No config saved: keep only the init configs, no tokenizer/model.
            self._hfmodel = HFObjects(
                None,
                None,
                None,
                msg["_init_tokenizer_config"],
                msg["_init_transformer_config"],
            )
        return self
Пример #16
0
    def load(cls, lang, filepath):
        """
        Load previously saved :class:`Corpus` binary data, reproduce the original
        :class:`spacy.tokens.Doc` objects' tokens and annotations, and instantiate
        a new :class:`Corpus` from them.

        Args:
            lang (str or :class:`spacy.language.Language`)
            filepath (str): Full path to file on disk where :class:`Corpus` data
                was previously saved as a binary file.

        Returns:
            :class:`Corpus`

        See Also:
            :meth:`Corpus.save()`
        """
        spacy_lang = _get_spacy_lang(lang)
        with tio.open_sesame(filepath, mode="rb") as f:
            msg = srsly.msgpack_loads(f.read())
        if spacy_lang.meta != msg["meta"]:
            LOGGER.warning("the spacy langs are different!")
        # Look up every saved string in the vocab first (presumably registering
        # it in the StringStore) so the ORTH ids below can be mapped to words.
        for string in msg["strings"]:
            spacy_lang.vocab[string]
        attrs = msg["attrs"]
        lengths = np.frombuffer(msg["lengths"], dtype="int32")
        flat_tokens = np.frombuffer(msg["tokens"], dtype="uint64")
        flat_tokens = flat_tokens.reshape(
            (flat_tokens.size // len(attrs), len(attrs)))
        # One (n_tokens, n_attrs) array per doc. Do NOT wrap this in
        # np.asarray(): when docs differ in length that builds a ragged array,
        # which NumPy >= 1.24 rejects with ValueError.
        tokens = NumpyOps().unflatten(flat_tokens, lengths)
        user_datas = msg["user_datas"]

        def _make_spacy_docs(tokens, user_datas):
            # Lazily rebuild each Doc. Column 0 holds ORTH ids, column 1 the
            # trailing-space flags, and columns 2+ the values for attrs[2:].
            for toks, user_data in compat.zip_(tokens, user_datas):
                doc = spacy.tokens.Doc(
                    spacy_lang.vocab,
                    words=[
                        spacy_lang.vocab.strings[orth] for orth in toks[:, 0]
                    ],
                    spaces=np.ndarray.tolist(toks[:, 1]),
                )
                doc = doc.from_array(attrs[2:], toks[:, 2:])
                doc.user_data = user_data
                yield doc

        return cls(spacy_lang, data=_make_spacy_docs(tokens, user_datas))
Пример #17
0
    def from_bytes(self, patterns_bytes, **kwargs):
        """Load the entity ruler from a bytestring.

        If the decoded value is a dict, its "patterns" (or the dict itself when
        that key is missing), "overwrite" and "ent_id_sep" entries are
        restored. Otherwise, the decoded value is treated as the patterns.

        patterns_bytes (bytes): The bytestring to load.
        **kwargs: Other config parameters, mostly for consistency.
        RETURNS (EntityRuler): The loaded entity ruler.

        DOCS: https://spacy.io/api/entityruler#from_bytes
        """
        cfg = srsly.msgpack_loads(patterns_bytes)
        if isinstance(cfg, dict):
            # Restore settings BEFORE adding patterns. add_patterns presumably
            # builds labels for patterns with an "id" as
            # label + ent_id_sep + id (as spaCy's EntityRuler does -- confirm).
            # Setting the separator afterwards would bake the previous
            # separator into those labels.
            self.overwrite = cfg.get('overwrite', False)
            self.ent_id_sep = cfg.get('ent_id_sep', DEFAULT_ENT_ID_SEP)
            self.add_patterns(cfg.get('patterns', cfg))
        else:
            self.add_patterns(cfg)
        return self
Пример #18
0
    def from_bytes(self, patterns_bytes, **kwargs):
        """Load the entity ruler from a bytestring.

        If the decoded value is a dict, its "overwrite", "phrase_matcher_attr"
        and "ent_id_sep" settings are restored first. Then its "patterns" (or
        the dict itself when that key is missing) are added. Otherwise, the
        decoded value is treated as the patterns.

        patterns_bytes (bytes): The bytestring to load.
        **kwargs: Other config parameters, mostly for consistency.
        RETURNS (EntityRuler): The loaded entity ruler.

        DOCS: https://spacy.io/api/entityruler#from_bytes
        """
        cfg = srsly.msgpack_loads(patterns_bytes)
        if isinstance(cfg, dict):
            self.overwrite = cfg.get("overwrite", False)
            self.phrase_matcher_attr = cfg.get("phrase_matcher_attr", None)
            if self.phrase_matcher_attr is not None:
                # Must run before add_patterns: add_patterns presumably
                # registers phrase patterns on self.phrase_matcher, so
                # replacing the matcher afterwards would drop them.
                self.phrase_matcher = PhraseMatcher(
                    self.nlp.vocab, attr=self.phrase_matcher_attr)
            # Also set before add_patterns, which presumably uses it to build
            # labels for patterns with an "id" (confirm).
            self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
            self.add_patterns(cfg.get("patterns", cfg))
        else:
            self.add_patterns(cfg)
        return self
Пример #19
0
    def get_docs(self, vocab):
        """Recover Doc objects from the annotations, using the given vocab.

        Each stored string is looked up in ``vocab`` first, so ORTH ids can
        be mapped back to words. When user data is stored, it is decoded with
        ``use_list=False``, which yields tuples instead of lists.

        vocab (Vocab): The shared vocab.
        YIELDS (Doc): The Doc objects.

        DOCS: https://spacy.io/api/docbin#get_docs
        """
        for text in self.strings:
            vocab[text]
        orth_column = self.attrs.index(ORTH)
        for idx, token_rows in enumerate(self.tokens):
            doc_spaces = self.spaces[idx]
            doc_words = [vocab.strings[orth_id] for orth_id in token_rows[:, orth_column]]
            doc = Doc(vocab, words=doc_words, spaces=doc_spaces)
            doc = doc.from_array(self.attrs, token_rows)
            if self.store_user_data:
                # Tuples presumably keep user_data keys hashable -- confirm.
                decoded_user_data = srsly.msgpack_loads(
                    self.user_data[idx], use_list=False
                )
                doc.user_data.update(decoded_user_data)
            yield doc
Пример #20
0
    def from_bytes(self, serial: bytes, **kwargs):
        """Load waterwheel from a bytestring.

        Only a decoded dict is applied; any other decoded value is ignored.

        Parameters
        ----------
        serial : bytes
            The serialized bytes data.

        Returns
        -------
        self : WaterWheel
            The loaded WaterWheel object.
        """
        cfg = srsly.msgpack_loads(serial)
        if not isinstance(cfg, dict):
            return self

        # Entity ids arrive with stringified hash keys.
        for ent_hash, label in cfg.get('vocab', {}).items():
            self._ent_ids[int(ent_hash)] = label
            lowered = label.lower()
            self._qualifiers[label] = [lowered, lowered + 's']
        # NOTE(review): raises KeyError when no 'MOUNTAIN' label was loaded,
        # unless _qualifiers is a defaultdict -- confirm intended.
        self._qualifiers['MOUNTAIN'].extend(['mount', 'mounts', 'mt.'])
        self._stop_words = cfg.get('stop_words', [])
        self._stop_words = set(self._stop_words)
        self._wikidata = cfg.get('wikidata', {})

        self._doc_bins = {
            key: DocBin().from_bytes(raw_bin)
            for key, raw_bin in cfg.get('doc_bins', {}).items()
        }
        phrases_by_key = {}
        for key, doc_bin in self._doc_bins.items():
            phrases_by_key[key] = list(doc_bin.get_docs(self.nlp.vocab))
        for key, phrases in phrases_by_key.items():
            self.phrase_matcher.add(key.upper(), phrases)
        return self
Пример #21
0
    def from_bytes(
        self, patterns_bytes: bytes, *, exclude: Iterable[str] = SimpleFrozenList()
    ) -> "EntityRuler":
        """Load the entity ruler from a bytestring.

        Existing patterns are cleared first. If the decoded value is a dict,
        its "overwrite", "phrase_matcher_attr" and "ent_id_sep" settings are
        restored before its "patterns" (or the dict itself when that key is
        missing) are added. Otherwise, the decoded value is treated as the
        patterns.

        patterns_bytes (bytes): The bytestring to load.
        exclude (Iterable[str]): Accepted for API consistency; not used here.
        RETURNS (EntityRuler): The loaded entity ruler.

        DOCS: https://spacy.io/api/entityruler#from_bytes
        """
        cfg = srsly.msgpack_loads(patterns_bytes)
        self.clear()
        if isinstance(cfg, dict):
            self.overwrite = cfg.get("overwrite", False)
            self.phrase_matcher_attr = cfg.get("phrase_matcher_attr", None)
            if self.phrase_matcher_attr is not None:
                # Must run before add_patterns: add_patterns presumably adds
                # phrase patterns to self.phrase_matcher, so replacing the
                # matcher afterwards would leave it empty.
                self.phrase_matcher = PhraseMatcher(
                    self.nlp.vocab, attr=self.phrase_matcher_attr
                )
            # Also set before add_patterns, which presumably uses it to build
            # labels for patterns with an "id" (confirm).
            self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
            self.add_patterns(cfg.get("patterns", cfg))
        else:
            self.add_patterns(cfg)
        return self
    def from_bytes(self, bytes_data):
        """Deserialize the DocBin's annotations from a bytestring.

        The payload is zlib-compressed msgpack holding ``attrs``, ``strings``,
        the flattened ``lengths`` / ``spaces`` / ``tokens`` buffers and,
        optionally, ``user_data``.

        bytes_data (bytes): The data to load from.
        RETURNS (DocBin): The loaded DocBin.

        DOCS: https://spacy.io/api/docbin#from_bytes
        """
        msg = srsly.msgpack_loads(zlib.decompress(bytes_data))
        self.attrs = msg["attrs"]
        self.strings = set(msg["strings"])
        # numpy.fromstring's binary mode is deprecated; numpy.frombuffer is the
        # supported replacement. frombuffer returns a read-only view of the
        # bytes, so the arrays kept on self are copied to stay writable, as they
        # were with fromstring.
        lengths = numpy.frombuffer(msg["lengths"], dtype="int32")
        flat_spaces = numpy.frombuffer(msg["spaces"], dtype=bool).copy()
        flat_tokens = numpy.frombuffer(msg["tokens"], dtype="uint64").copy()
        # One row per token, one column per attribute in self.attrs.
        shape = (flat_tokens.size // len(self.attrs), len(self.attrs))
        flat_tokens = flat_tokens.reshape(shape)
        flat_spaces = flat_spaces.reshape((flat_spaces.size, 1))
        ops = NumpyOps()
        self.tokens = ops.unflatten(flat_tokens, lengths)
        self.spaces = ops.unflatten(flat_spaces, lengths)
        if self.store_user_data and "user_data" in msg:
            self.user_data = list(msg["user_data"])
        for tokens in self.tokens:
            assert len(tokens.shape) == 2, tokens.shape  # this should never happen
        return self
Пример #23
0
 def deserialize_pkuseg_processors(b):
     nonlocal pkuseg_processors_data
     pkuseg_processors_data = srsly.msgpack_loads(b)
Пример #24
0
 def load_patterns(b):
     """Decode msgpack bytes ``b`` and add the result via ``self.add_patterns``."""
     decoded_patterns = srsly.msgpack_loads(b)
     self.add_patterns(decoded_patterns)
Пример #25
0
 def from_bytes(self, bytes_data, exclude=tuple(), **kwargs):
     """Restore each attribute named in ``self.serialization_fields`` from msgpack bytes.

     Each field is looked up by name in the decoded message and set on
     ``self``, then ``finish_deserializing`` is called. ``exclude`` and
     ``**kwargs`` are accepted but not used. Returns ``self``.
     """
     decoded = srsly.msgpack_loads(bytes_data)
     for field_name in self.serialization_fields:
         value = decoded[field_name]
         setattr(self, field_name, value)
     self.finish_deserializing()
     return self
Пример #26
0
def test_serialize_transformer_data():
    """An empty TransformerData nested in a dict survives a msgpack round trip."""
    payload = srsly.msgpack_dumps({"x": TransformerData.empty()})
    restored = srsly.msgpack_loads(payload)
    assert isinstance(restored["x"], TransformerData)
Пример #27
0
def deserialize_attr(_: Any, value: Any, name: str, model: Model) -> Any:
    """Deserialize an attribute value (msgpack by default).

    Custom deserializers can be registered per type with the
    ``@deserialize_attr.register`` decorator, e.g.
    ``@deserialize_attr.register(MyCustomObject)``. This default ignores
    ``name`` and ``model`` and msgpack-decodes ``value``.
    """
    decoded = srsly.msgpack_loads(value)
    return decoded
Пример #28
0
 def from_bytes(self, byte_string: bytes) -> "TransformerData":
     """Populate this object from msgpack bytes via ``from_dict`` and return it."""
     self.from_dict(srsly.msgpack_loads(byte_string))
     return self
Пример #29
0
 def deserialize_pkuseg_processors(b):
     """Store the msgpack-decoded ``b`` under ``pkuseg_data["processors_data"]``."""
     decoded = srsly.msgpack_loads(b)
     pkuseg_data["processors_data"] = decoded
Пример #30
0
 def from_bytes(self, bytes_data):
     """Restore config and parameters from msgpack bytes and return ``self``.

     The ``"config"`` entry is stored on ``self.cfg``, and the ``"state"``
     entry is handed to ``self._load_params``.
     """
     payload = srsly.msgpack_loads(bytes_data)
     self.cfg = payload["config"]
     saved_state = payload["state"]
     self._load_params(saved_state)
     return self