コード例 #1
0
    def shallow_clone(self) -> Trs:
        """
        Create a new `Trs` handle from this one via `trs_vec_shallow_clone`.

        Returns:
            A new `Trs` wrapping the pointer written by the FFI call.

        Raises:
            Presumably whatever `check_ffi_error` raises when the FFI call
            reports a failure code.
        """
        # NOTE(review): "shallow" presumably means the underlying transitions are
        # shared rather than copied — confirm against the Rust implementation.
        from ctypes import byref

        new_trs_ptr = c_void_p()
        # Pass the out-parameter by reference (as `trs_vec_new` / `trs_vec_remove`
        # callers do) so the callee writes into `new_trs_ptr` instead of
        # receiving its current (NULL) value.
        exit_code = lib.trs_vec_shallow_clone(self._ptr, byref(new_trs_ptr))
        err_msg = "Something went wrong when cloning Trs"
        check_ffi_error(exit_code, err_msg)

        return Trs(new_trs_ptr)
コード例 #2
0
def top_sort(fst: VectorFst) -> VectorFst:
    """
    Topologically sort an Fst so that every transition goes from a lower
    state ID to a higher one.

    Examples :

    - Input

    ![topsort_in](https://raw.githubusercontent.com/Garvys/rustfst-images-doc/master/images/topsort_in.svg?sanitize=true)

    - Output

    ![topsort_out](https://raw.githubusercontent.com/Garvys/rustfst-images-doc/master/images/topsort_out.svg?sanitize=true)

    Args:
        fst: Fst to sort topologically.
    Returns:
        An equivalent, topologically sorted Fst: a `VectorFst` wrapping the
        pointer written by `fst_top_sort`. Per the original documentation the
        input is also modified in-place (NOTE(review): confirm in the Rust side).
    """
    # Out-parameter that receives the address of the sorted Fst.
    sorted_ptr = ctypes.c_void_p()
    status = lib.fst_top_sort(fst.ptr, ctypes.byref(sorted_ptr))
    check_ffi_error(status, "Error during top_sort")
    return VectorFst(ptr=sorted_ptr)
コード例 #3
0
 def delete_states(self):
     """
     Remove all states by calling `vec_fst_delete_states` on this Fst's pointer.

     Raises:
         Whatever `check_ffi_error` raises for a failing return code.
     """
     status = lib.vec_fst_delete_states(self.ptr)
     check_ffi_error(status, "Error deleting states")
コード例 #4
0
 def remove(self, index: int) -> Tr:
     """
     Remove the transition stored at position `index` and return it.

     Args:
         index: Position of the transition to remove.
     Returns:
         A `Tr` wrapping the pointer to the removed transition, as written by
         `trs_vec_remove`.
     """
     out_tr = c_void_p()
     status = lib.trs_vec_remove(self._ptr, index, byref(out_tr))
     check_ffi_error(
         status,
         "Something went wrong when removing transition at index: " + str(index),
     )
     return Tr(out_tr)
コード例 #5
0
ファイル: __init__.py プロジェクト: Garvys/rustfst
    def set_output_symbols(self, syms: Optional[SymbolTable]) -> Fst:
        """
        Attach an output symbol table to this Fst, or detach the current one.

        Args:
          syms: The `SymbolTable` to attach, or None to remove the current
            output symbol table.
        Returns:
          self, so calls can be chained.
        See also: `set_input_symbols`.
        """
        if syms is not None:
            status = lib.fst_set_output_symbols(self.ptr, syms.ptr)
            check_ffi_error(status, "Error setting output symbols")
        else:
            status = lib.fst_unset_output_symbols(self.ptr)
            check_ffi_error(status, "Error unsetting output symbols")

        # Mirror the table on the Python side: holding `syms` keeps it from
        # being garbage-collected while attached; None drops the reference.
        self._output_symbols = syms
        return self
コード例 #6
0
    def __init__(
        self,
        ilabel: Optional[int] = None,
        olabel: Optional[int] = None,
        weight: Optional[float] = None,
        nextstate: Optional[int] = None,
    ):
        """
        Create a new transition.

        Two calling conventions are supported:

        - ``Tr(ptr)``: when only the first argument is given and it is truthy,
          it is treated as an existing FFI pointer and stored as-is.
        - ``Tr(ilabel, olabel, weight, nextstate)``: a new transition is
          allocated through `tr_new`; ``weight`` defaults to ``weight_one()``.

        Args:
            ilabel: The input label (or an existing pointer, see above).
            olabel: The output label.
            weight: The transition's weight. Defaults to ``weight_one()``.
            nextstate: The destination state for the transition.

        Raises:
            TypeError: if ``ilabel``, ``olabel`` or ``nextstate`` is missing
                when building a new transition.
        """
        if ilabel and olabel is None and weight is None and nextstate is None:
            # Wrapping constructor: `ilabel` actually holds an FFI pointer.
            self._ptr = ilabel
            return

        missing = [
            name
            for name, value in (
                ("ilabel", ilabel),
                ("olabel", olabel),
                ("nextstate", nextstate),
            )
            if value is None
        ]
        if missing:
            # c_size_t(None) would also raise TypeError, but with a message that
            # does not say which argument is missing.
            raise TypeError(
                "Tr requires ilabel, olabel and nextstate; missing: "
                + ", ".join(missing)
            )

        if weight is None:
            weight = weight_one()

        ptr = c_void_p()
        exit_code = lib.tr_new(
            c_size_t(ilabel),
            c_size_t(olabel),
            c_float(weight),
            c_size_t(nextstate),
            byref(ptr),
        )
        err_msg = "Something went wrong when creating the Tr struct"
        check_ffi_error(exit_code, err_msg)

        self._ptr = ptr
コード例 #7
0
    def __str__(self):
        """
        Return the textual rendering of this ConstFst produced by
        `const_fst_display`, decoded as UTF-8.
        """
        # The FFI call writes the address of a NUL-terminated byte string here.
        # NOTE(review): ownership/freeing of that buffer is not visible here.
        text_ptr = ctypes.c_void_p()
        check_ffi_error(
            lib.const_fst_display(self.ptr, ctypes.byref(text_ptr)),
            "Error displaying ConstFst",
        )
        raw_bytes = ctypes.string_at(text_ptr)
        return raw_bytes.decode("utf8")
コード例 #8
0
ファイル: connect.py プロジェクト: Garvys/rustfst
def connect(fst: VectorFst) -> VectorFst:
    """
    This operation trims an Fst, removing states and trs that are not on successful paths.

    Examples :

    - Input :

    ![connect_in](https://raw.githubusercontent.com/Garvys/rustfst-images-doc/master/images/connect_in.svg?sanitize=true)

    - Output :

    ![connect_out](https://raw.githubusercontent.com/Garvys/rustfst-images-doc/master/images/connect_out.svg?sanitize=true)

    Args:
        fst: Fst to trim.

    Returns :
        A new `VectorFst` wrapper around the connected Fst written by
        `fst_connect` (not the `fst` object passed in).
        NOTE(review): whether `fst` itself is also modified depends on
        `fst_connect` — confirm against the Rust implementation.
    """

    # Out-parameter that receives the address of the connected Fst.
    connected_fst = ctypes.c_void_p()
    ret_code = lib.fst_connect(fst.ptr, ctypes.byref(connected_fst))
    err_msg = "Error during connect"
    check_ffi_error(ret_code, err_msg)

    return VectorFst(ptr=connected_fst)
コード例 #9
0
    def member(self, key: Union[int, str]) -> bool:
        """
        Given a symbol or index, returns whether it is found in the table.
        This method returns a boolean indicating whether the given symbol or index
        is present in the table. If one intends to perform subsequent lookup, it is
        better to simply call the find method, catching the KeyError.
        Args:
          key: Either a string (looked up as a symbol) or an integer (looked up
            as an index).
        Returns:
          Whether or not the key is present (as a string or an index) in the table.
        Raises:
          TypeError: if `key` is neither an `int` nor a `str`.
        """
        is_present = ctypes.c_size_t()

        if isinstance(key, int):
            index = ctypes.c_size_t(key)
            ret_code = lib.symt_member_index(self.ptr, index,
                                             ctypes.byref(is_present))
        elif isinstance(key, str):
            symbol = key.encode("utf-8")
            ret_code = lib.symt_member_symbol(self.ptr, symbol,
                                              ctypes.byref(is_present))
        else:
            # Raising a bare string is itself a TypeError ("exceptions must
            # derive from BaseException") that hides this message, so raise a
            # real exception object instead.
            raise TypeError(
                f"key can only be a string or integer. Not {type(key)}")

        err_msg = "`member` failed"
        check_ffi_error(ret_code, err_msg)

        return bool(is_present.value)
コード例 #10
0
def union(fst: VectorFst, other_fst: VectorFst) -> VectorFst:
    """
    Compute the union of two wFSTs. If A transduces string `x` to `y` with weight `a`
    and `B` transduces string `w` to `v` with weight `b`, then their union transduces `x` to `y`
    with weight `a` and `w` to `v` with weight `b`.

    Examples:
    - Input Fst 1:

    ![union_in_1](https://raw.githubusercontent.com/Garvys/rustfst-images-doc/master/images/union_in_1.svg?sanitize=true)

    - Input Fst 2:

    ![union_in_2](https://raw.githubusercontent.com/Garvys/rustfst-images-doc/master/images/union_in_2.svg?sanitize=true)

    - Union:

    ![union_out](https://raw.githubusercontent.com/Garvys/rustfst-images-doc/master/images/union_out.svg?sanitize=true)

    Args:
        fst: First operand. Its pointer is handed to `fst_union` and this same
            object is returned (presumably updated in place — confirm).
        other_fst: Second operand.
    Returns:
         The `fst` object passed in, after `fst_union` has run.

    """
    status = lib.fst_union(fst.ptr, other_fst.ptr)
    check_ffi_error(status, "Error during union")
    return fst
コード例 #11
0
ファイル: __init__.py プロジェクト: Garvys/rustfst
def acceptor(astring: str,
             symbol_table: SymbolTable,
             weight: Optional[float] = None) -> VectorFst:
    """
    Creates an acceptor from a string.
    This function creates a FST which accepts its input with a fixed weight
    (defaulting to semiring One).
    Args:
      astring: The input string; it is encoded as UTF-8 before being passed to
        the FFI.
      symbol_table: SymbolTable to be used to encode the string.
      weight: The desired path weight as a number (it is converted with
        `ctypes.c_float`, so weight strings are not accepted). If omitted or
        None, the path weight is set to semiring One (`weight_one()`).
    Returns:
      An FST acceptor.
    """
    if weight is None:
        weight = weight_one()

    # Module-wide out-parameter pattern: the FFI call writes the new Fst's
    # address into the storage of this pointer object (passed by reference).
    acceptor_fst_ptr = ctypes.pointer(ctypes.c_void_p())
    ret_code = lib.utils_string_to_acceptor(
        astring.encode("utf-8"),
        symbol_table.ptr,
        ctypes.c_float(weight),
        ctypes.byref(acceptor_fst_ptr),
    )
    err_msg = "Error creating acceptor FST"
    check_ffi_error(ret_code, err_msg)
    return VectorFst(ptr=acceptor_fst_ptr)
コード例 #12
0
    def _find_index(self, key: int) -> str:
        """Return the symbol stored at index `key`, decoded from UTF-8."""
        c_key = ctypes.c_size_t(key)
        # Receives the address of a NUL-terminated string from the FFI call.
        symbol_ptr = ctypes.c_void_p()
        status = lib.symt_find_index(self.ptr, c_key, ctypes.byref(symbol_ptr))
        check_ffi_error(status, "`find` failed")
        return ctypes.string_at(symbol_ptr).decode("utf8")
コード例 #13
0
ファイル: iterators.py プロジェクト: Garvys/rustfst
 def reset(self):
     """
     reset(self)
         Move the iterator back to its initial position.
     """
     status = lib.trs_iterator_reset(self._ptr)
     check_ffi_error(status, "`reset` failed")
コード例 #14
0
    def _find_symbol(self, symbol: str) -> int:
        """Return the index associated with `symbol` (sent to the FFI as UTF-8)."""
        encoded = symbol.encode("utf-8")
        found_index = ctypes.c_size_t()
        status = lib.symt_find_symbol(self.ptr, encoded, ctypes.byref(found_index))
        check_ffi_error(status, "`find` failed")
        return int(found_index.value)
コード例 #15
0
ファイル: iterators.py プロジェクト: Garvys/rustfst
 def __next__(self):
     """
     Move the underlying mutable tr iterator forward by one position.

     This method always returns None and does not itself raise StopIteration.
     :return: None
     """
     status = lib.mut_trs_iterator_next(self._ptr)
     check_ffi_error(status, "`next` failed")
コード例 #16
0
    def relabel_tables(
        self,
        *,
        old_isymbols: Optional[SymbolTable] = None,
        new_isymbols: SymbolTable,
        attach_new_isymbols: bool = True,
        old_osymbols: Optional[SymbolTable] = None,
        new_osymbols: SymbolTable,
        attach_new_osymbols: bool = True,
    ) -> VectorFst:
        """
        Destructively relabel the Fst so its labels refer to new Symbol Tables.

        Every label is mapped to the label that the same symbol has in the new
        `SymbolTable`. For instance, if label `1` means "alpha" in the current
        table and "alpha" is `4` in the new one, every `1` becomes `4`.

        Args:
            old_isymbols: Input `SymbolTable` the Fst was built with. If `None`, the Input `SymbolTable` attached to the Fst is used.
            new_isymbols: New Input `SymbolTable`.
            attach_new_isymbols: Whether to attach `new_isymbols` to the Fst. If False, the resulting Fst has no attached Input `SymbolTable`.
            old_osymbols: Output `SymbolTable` the Fst was built with. If `None`, the Output `SymbolTable` attached to the Fst is used.
            new_osymbols: New Output `SymbolTable`.
            attach_new_osymbols: Whether to attach `new_osymbols` to the Fst. If False, the resulting Fst has no attached Output `SymbolTable`.

        Returns:
            self

        """

        def _ptr_or_none(table):
            # A missing table is forwarded to the FFI as NULL.
            return None if table is None else table.ptr

        old_in_ptr = _ptr_or_none(old_isymbols)
        old_out_ptr = _ptr_or_none(old_osymbols)

        status = lib.vec_fst_relabel_tables(
            self.ptr,
            old_in_ptr,
            new_isymbols.ptr,
            ctypes.c_size_t(attach_new_isymbols),
            old_out_ptr,
            new_osymbols.ptr,
            ctypes.c_size_t(attach_new_osymbols),
        )
        check_ffi_error(status, "Relabel tables failed")

        # Symbol tables are cached on the Python side, so the cache must match
        # what is now attached.
        self._input_symbols = new_isymbols if attach_new_isymbols else None
        self._output_symbols = new_osymbols if attach_new_osymbols else None

        return self
コード例 #17
0
 def push(self, tr: Tr):
     """
     Add `tr` to this list of transitions via `trs_vec_push`.

     Args:
         tr: The transition to add.
     """
     status = lib.trs_vec_push(self._ptr, tr.ptr)
     check_ffi_error(status, "Something went wrong when adding new transition")
コード例 #18
0
ファイル: iterators.py プロジェクト: Garvys/rustfst
    def __init__(self, fst: Fst) -> None:
        """
        Create a state iterator for `fst` via `state_iterator_new`.

        Args:
            fst: The Fst whose states are iterated.
        """
        # Reference the Fst to prolong its lifetime (prevent early gc) while
        # the native iterator is in use.
        self.ptr = fst
        iter_ptr = ctypes.pointer(ctypes.c_void_p())

        ret_code = lib.state_iterator_new(fst.ptr, ctypes.byref(iter_ptr))
        err_msg = "`__init__` failed"
        check_ffi_error(ret_code, err_msg)

        self._ptr = iter_ptr
コード例 #19
0
ファイル: iterators.py プロジェクト: Garvys/rustfst
 def set_value(self, tr: Tr):
     """
     set_value(self, tr)
         Overwrite the tr at the iterator's current position.
         Args:
           tr: The tr that takes the place of the current one.
     """
     status = lib.mut_trs_iterator_set_value(self._ptr, tr.ptr)
     check_ffi_error(status, "`set_value` failed")
コード例 #20
0
def optimize(fst: VectorFst):
    """
    Optimize an Fst by passing its pointer to `fst_optimize`.

    Nothing is returned, so the Fst is presumably optimized in place
    (confirm against the Rust implementation).

    Args:
        fst: Fst to optimize.
    """
    status = lib.fst_optimize(fst.ptr)
    check_ffi_error(status, "Error during optimize")
コード例 #21
0
 def add_table(self, syms: SymbolTable):
     """
     Merge another symbol table into this one via `symt_add_table`.

     Per the original documentation, all key values are offset by the
     current available key.
     Args:
       syms: The `SymbolTable` to merge into the current table.
     """
     status = lib.symt_add_table(self.ptr, syms.ptr)
     check_ffi_error(status, "`add_table` failed")
コード例 #22
0
    def copy(self) -> ConstFst:
        """
        Returns :
            Deepcopy of the Fst, as a new `ConstFst` wrapping the pointer
            written by `const_fst_copy`.
        """
        # The out-parameter holds an address, so use c_void_p (as the other FFI
        # wrappers do) rather than c_size_t; c_void_p is also what functions
        # declaring c_void_p in their argtypes expect to receive later.
        cloned_fst = ctypes.c_void_p()
        ret_code = lib.const_fst_copy(self.ptr, ctypes.byref(cloned_fst))
        err_msg = "Error copying fst"
        check_ffi_error(ret_code, err_msg)

        return ConstFst(cloned_fst)
コード例 #23
0
ファイル: iterators.py プロジェクト: Garvys/rustfst
    def __init__(self, fst: Fst, state_id: int) -> None:
        """
        Create a mutable iterator over the trs leaving `state_id` in `fst`,
        via `mut_trs_iterator_new`.

        Args:
            fst: The Fst that owns the state.
            state_id: Id of the state whose trs are iterated.
        """
        # Reference the Fst to prolong its lifetime (prevent early gc) while
        # the native iterator is in use.
        self.ptr = fst
        c_state_id = ctypes.c_size_t(state_id)
        iter_ptr = ctypes.pointer(ctypes.c_void_p())

        ret_code = lib.mut_trs_iterator_new(fst.ptr, c_state_id,
                                            ctypes.byref(iter_ptr))
        err_msg = "`__init__` failed"
        check_ffi_error(ret_code, err_msg)

        self._ptr = iter_ptr
コード例 #24
0
    def num_symbols(self) -> int:
        """
        Returns:
            How many symbols the symbol table holds, as reported by `symt_num_symbols`.
        """
        count = ctypes.c_size_t()
        check_ffi_error(
            lib.symt_num_symbols(self.ptr, ctypes.byref(count)),
            "`num_symbols` failed",
        )
        return int(count.value)
コード例 #25
0
ファイル: tr_unique.py プロジェクト: Garvys/rustfst
def tr_unique(fst: VectorFst):
    """
    Deduplicate transitions. Among trs that leave the same state, reach the
    same state, and share input label, output label and weight, only one is kept.

    Args:
        fst: Fst to modify (its pointer is passed to `fst_tr_unique`).
    """
    status = lib.fst_tr_unique(fst.ptr)
    check_ffi_error(status, "Error during tr_unique")
コード例 #26
0
    def copy(self) -> VectorFst:
        """
        Returns:
            A new `VectorFst` wrapping the copy that `vec_fst_copy` writes out.
        """
        # Out-parameter: the FFI call stores the copy's address in this object.
        copy_ptr = ctypes.pointer(ctypes.c_void_p())
        status = lib.vec_fst_copy(self.ptr, ctypes.byref(copy_ptr))
        check_ffi_error(status, "Error copying fst")
        return VectorFst(copy_ptr)
コード例 #27
0
def tr_sort(fst: VectorFst, ilabel_cmp: bool):
    """
    Sort the transitions of `fst` by input label or by output label.

    Args:
        fst: Fst whose transitions are sorted (its pointer is passed to `fst_tr_sort`).
        ilabel_cmp: Selects the sort key; presumably True sorts by ilabel and
            False by olabel — confirm against `fst_tr_sort`.
    """
    status = lib.fst_tr_sort(fst.ptr, ctypes.c_bool(ilabel_cmp))
    check_ffi_error(status, "Error during tr_sort")
コード例 #28
0
 def __init__(self, ptr=None) -> None:
     """
     Create a list of transitions.

     Args:
         ptr: Optional existing FFI pointer to wrap as-is. When None (the
             default), a new, empty list of transitions is allocated with
             `trs_vec_new`.
     """
     if ptr is None:
         self._ptr = c_void_p()
         exit_code = lib.trs_vec_new(byref(self._ptr))
         err_msg = "Something went wrong when creating the Trs struct"
         check_ffi_error(exit_code, err_msg)
     else:
         self._ptr = ptr
コード例 #29
0
ファイル: weight.py プロジェクト: Garvys/rustfst
def weight_one() -> float:
    """
    Return One() of the Tropical Semiring.

    Returns:
        The float that `fst_weight_one` reports for One() in the Tropical Semiring.
    """
    one = ctypes.c_float()
    check_ffi_error(lib.fst_weight_one(ctypes.byref(one)), "weight_one failed")
    return float(one.value)
コード例 #30
0
    def copy(self) -> SymbolTable:
        """
        Returns:
            A mutable copy of the `SymbolTable`, wrapping the pointer `symt_copy` writes out.
        """
        # Out-parameter: the FFI call stores the copy's address in this object.
        copy_ptr = ctypes.pointer(ctypes.c_void_p())
        status = lib.symt_copy(self.ptr, ctypes.byref(copy_ptr))
        check_ffi_error(status, "`copy` failed.")
        return SymbolTable(ptr=copy_ptr)