def load_finalfusion(file: Union[str, bytes, int, PathLike],
                     mmap: bool = False) -> Embeddings:
    """
    Read embeddings from a file in finalfusion format.

    Parameters
    ----------
    file : str, bytes, int, PathLike
        Path to a file with embeddings in finalfusion format.
    mmap : bool
        Toggles memory mapping the storage buffer.

    Returns
    -------
    embeddings : Embeddings
        The embeddings from the input file.
    """
    with open(file, 'rb') as inf:
        _ = Header.read_chunk(inf)
        chunk_id, _ = _read_required_chunk_header(inf)
        norms = None
        metadata = None

        if chunk_id == ChunkIdentifier.Metadata:
            metadata = Metadata.read_chunk(inf)
            chunk_id, _ = _read_required_chunk_header(inf)

        if chunk_id == ChunkIdentifier.SimpleVocab:
            vocab = SimpleVocab.read_chunk(inf)  # type: Vocab
        elif chunk_id == ChunkIdentifier.BucketSubwordVocab:
            vocab = FinalfusionBucketVocab.read_chunk(inf)
        elif chunk_id == ChunkIdentifier.FastTextSubwordVocab:
            vocab = FastTextVocab.read_chunk(inf)
        elif chunk_id == ChunkIdentifier.ExplicitSubwordVocab:
            vocab = ExplicitVocab.read_chunk(inf)
        else:
            raise FinalfusionFormatError(
                f'Expected vocab chunk, not {str(chunk_id)}')

        chunk_id, _ = _read_required_chunk_header(inf)
        if chunk_id == ChunkIdentifier.NdArray:
            storage = NdArray.load(inf, mmap)  # type: Storage
        elif chunk_id == ChunkIdentifier.QuantizedArray:
            storage = QuantizedArray.load(inf, mmap)
        else:
            raise FinalfusionFormatError(
                f'Expected storage chunk, not {str(chunk_id)}')
        maybe_chunk_id = _read_chunk_header(inf)
        if maybe_chunk_id is not None:
            if maybe_chunk_id[0] == ChunkIdentifier.NdNorms:
                norms = Norms.read_chunk(inf)
            else:
                raise FinalfusionFormatError(
                    f'Expected norms chunk, not {str(maybe_chunk_id[0])}')

        return Embeddings(storage, vocab, norms, metadata, inf.name)
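
# Usage sketch (not part of the function above): load embeddings and look up a
# vector. Assumes the `Embeddings.embedding` accessor; the file name and query
# word below are hypothetical.
embeds = load_finalfusion("wiki.fifu", mmap=True)  # memory-map the storage
vector = embeds.embedding("berlin")                # vector, or None if the word cannot be embedded
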
def _read_array_header(file: BinaryIO) -> Tuple[int, int]:
    """
    Helper method to read the header of an NdArray chunk.

    The method reads the shape tuple, verifies the TypeId and seeks the file to the start
    of the array. The shape tuple is returned.

    Parameters
    ----------
    file : BinaryIO
        finalfusion file with a storage at the start of a NdArray chunk.

    Returns
    -------
    shape : Tuple[int, int]
        Shape of the storage.

    Raises
    ------
    FinalfusionFormatError
        If the TypeId does not match TypeId.f32
    """
    rows, cols = _read_required_binary(file, "<QI")
    type_id = TypeId(_read_required_binary(file, "<I")[0])
    if TypeId.f32 != type_id:
        raise FinalfusionFormatError(
            f"Invalid Type, expected {TypeId.f32}, got {type_id}")
    file.seek(_pad_float32(file.tell()), 1)
    return rows, cols
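
# The low-level helpers used above are not part of this excerpt. A minimal
# sketch of their assumed behaviour: unpack a struct format while failing
# loudly on short reads, and compute the padding needed to align the next
# read to a 4-byte float32 boundary. The names carry a `_sketch` suffix to
# mark them as illustrations; FinalfusionFormatError is reused from above.
import struct
from typing import BinaryIO, Tuple


def _read_required_binary_sketch(file: BinaryIO, fmt: str) -> Tuple:
    # Read exactly struct.calcsize(fmt) bytes and unpack them.
    size = struct.calcsize(fmt)
    buf = file.read(size)
    if len(buf) != size:
        raise FinalfusionFormatError(f'Could not read {size} bytes from file')
    return struct.unpack(fmt, buf)


def _pad_float32_sketch(pos: int) -> int:
    # Padding bytes so that pos + padding is a multiple of 4.
    return -pos % struct.calcsize('<f')
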
def _read_quantized_header(
        file: BinaryIO
) -> Tuple[PQ, Tuple[int, int], Optional[np.ndarray]]:
    """
    Helper method to read the header of a quantized array chunk.
    Returns a tuple containing PQ, quantized_shape and optional norms.
    """
    projection = _read_required_binary(file, '<I')[0] != 0
    read_norms = _read_required_binary(file, '<I')[0] != 0
    quantized_len = _read_required_binary(file, '<I')[0]
    reconstructed_len = _read_required_binary(file, '<I')[0]
    n_centroids = _read_required_binary(file, '<I')[0]
    n_embeddings = _read_required_binary(file, '<Q')[0]
    assert reconstructed_len % quantized_len == 0
    type_id = _read_required_binary(file, '<I')[0]
    if int(TypeId.u8) != type_id:
        raise FinalfusionFormatError(
            f"Invalid Type, expected {str(TypeId.u8)}, got {type_id}")
    type_id = _read_required_binary(file, '<I')[0]
    if int(TypeId.f32) != type_id:
        raise FinalfusionFormatError(
            f"Invalid Type, expected {str(TypeId.f32)}, got {type_id}")
    file.seek(_pad_float32(file.tell()), 1)
    if projection:
        projection = _read_array_as_native(file, np.float32,
                                           reconstructed_len**2)
        projection_shape = (reconstructed_len, reconstructed_len)
        projection = projection.reshape(projection_shape)
    else:
        projection = None
    quantizer_shape = (quantized_len, n_centroids,
                       reconstructed_len // quantized_len)
    quantizers_size = quantized_len * n_centroids * (reconstructed_len //
                                                     quantized_len)
    quantizers = _read_array_as_native(file, np.float32, quantizers_size)
    quantizers = quantizers.reshape(quantizer_shape)
    if read_norms:
        norms = _read_array_as_native(file, np.float32, n_embeddings)
    else:
        norms = None
    quantizer = PQ(quantizers, projection)
    return quantizer, (n_embeddings, quantized_len), norms
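
# Shape illustration (a sketch, not library code): reconstruct one embedding
# from its uint8 codes with plain numpy. Assumes the usual product-quantization
# lookup-and-concatenate, with the transpose of the (assumed orthonormal)
# projection undoing the stored projection.
import numpy as np


def reconstruct_row_sketch(codes, quantizers, projection=None):
    # codes: (quantized_len,) centroid indices, one per subquantizer.
    # quantizers: (quantized_len, n_centroids, reconstructed_len // quantized_len).
    parts = [quantizers[i, code] for i, code in enumerate(codes)]
    vector = np.concatenate(parts)
    if projection is not None:
        vector = projection.T @ vector
    return vector
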
def read_chunk(file: BinaryIO) -> 'Norms':
    """Read a Norms chunk from the given file."""
    n_norms, dtype = _read_required_binary(file, "<QI")
    type_id = TypeId(dtype)
    if TypeId.f32 != type_id:
        raise FinalfusionFormatError(
            f"Invalid Type, expected {TypeId.f32}, got {str(type_id)}")
    padding = _pad_float32(file.tell())
    file.seek(padding, 1)
    array = np.fromfile(file=file, count=n_norms, dtype=np.float32)
    if sys.byteorder == "big":
        # Norms are stored little-endian; swap on big-endian hosts.
        array.byteswap(inplace=True)
    return Norms(array)
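
# The norms chunk keeps the original l2 norms of in-vocabulary words, whose
# stored vectors are normalized. A hedged sketch of undoing the normalization,
# assuming index-based `vocab`, `storage` and `norms` accessors on
# `Embeddings`; the word and file name are hypothetical.
embeds = load_finalfusion("wiki.fifu")
idx = embeds.vocab["berlin"]
unnormalized = embeds.storage[idx] * embeds.norms[idx]
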
def read_chunk(file: BinaryIO) -> 'Metadata':
    """Read a Metadata chunk from the given file."""
    chunk_header_size = struct.calcsize("<IQ")
    # Seek back to the start of the chunk header, since for metadata the
    # chunk size is the number of bytes that we need to read.
    file.seek(-chunk_header_size, 1)
    chunk_id, chunk_len = _read_required_binary(file, "<IQ")
    assert ChunkIdentifier(chunk_id) == Metadata.chunk_identifier()
    buf = file.read(chunk_len)
    if len(buf) != chunk_len:
        raise FinalfusionFormatError(
            f'Could not read {chunk_len} bytes from file')
    return Metadata(toml.loads(buf.decode("utf-8")))
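
# Metadata is free-form TOML. A small round-trip sketch (keys are
# hypothetical; assumes `Metadata` wraps a plain mapping, as suggested by the
# constructor call above, and that `Embeddings` exposes a `metadata` accessor):
meta = Metadata({"corpus": "wiki", "dims": 300})
loaded = load_finalfusion("wiki.fifu")
print(loaded.metadata)  # the mapping read back by read_chunk above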