예제 #1
0
파일: fsa.py 프로젝트: stormytrail/k2
    def __init__(self,
                 tensor: torch.Tensor,
                 aux_labels: Optional[torch.Tensor] = None) -> None:
        '''Initialize this Fsa from an arc tensor and optional aux_labels.

        This is handy when an Fsa has been loaded from a file.

        Args:
          tensor:
            A torch tensor of dtype `torch.int32` with 4 columns and one
            row per arc. In order, the columns hold the src_state, the
            dest_state, the label and the score.

            Caution:
              A score is a float whose binary pattern has been
              **reinterpreted** as an integer so that it can live in a
              tensor of dtype `torch.int32`; it is not a numeric cast.

          aux_labels:
            Optional. A 1-D tensor giving one aux_label per arc, i.e.
            it has as many entries as `tensor` has rows. It is converted
            to `torch.int32` before being attached to this Fsa.
        '''
        self._init_internal()

        ragged_arcs = _fsa_from_tensor(tensor)
        self.arcs = ragged_arcs
        self._init_properties()

        # The last column of the arc values carries the score bits;
        # `_as_float` is applied to it to obtain the `scores` attribute.
        score_bits = self.arcs.values()[:, -1]
        self._tensor_attr['scores'] = _as_float(score_bits)

        if aux_labels is None:
            return
        self.aux_labels = aux_labels.to(torch.int32)
예제 #2
0
파일: fsa.py 프로젝트: johnjosephmorgan/k2
    def from_tensor(cls,
                    tensor: torch.Tensor,
                    aux_labels: Optional[torch.Tensor] = None) -> 'Fsa':
        '''Create an Fsa from an arc tensor and optional aux_labels.

        This is handy when an Fsa has been loaded from a file.

        Args:
          tensor:
            A torch tensor of dtype `torch.int32` with 4 columns and one
            row per arc. In order, the columns hold the src_state, the
            dest_state, the label and the score.

            Caution:
              A score is a float whose binary pattern has been
              **reinterpreted** as an integer so that it can live in a
              tensor of dtype `torch.int32`; it is not a numeric cast.

          aux_labels:
            Optional. A 1-D tensor of dtype `torch.int32` giving one
            aux_label per arc, i.e. it has as many entries as `tensor`
            has rows. It is stored as given, without any conversion.

        Returns:
          A new instance of `cls`.
        '''
        # Allocate the object without running `cls.__init__`, then run
        # only the initializer of the class that follows Fsa in the MRO.
        instance = cls.__new__(cls)
        super(Fsa, instance).__init__()

        ragged_arcs = _fsa_from_tensor(tensor)
        instance._fsa = ragged_arcs
        instance._aux_labels = aux_labels
        return instance
예제 #3
0
파일: fsa.py 프로젝트: ts0923/k2
    def __init__(self,
                 arcs: Union[torch.Tensor, RaggedArc],
                 aux_labels: Optional[torch.Tensor] = None,
                 properties=None) -> None:
        '''Initialize this Fsa from its arcs and optional aux_labels.

        This is handy when an Fsa has been loaded from a file.

        Args:
          arcs:
            Either a `RaggedArc`, which is used as is, or a torch tensor
            of dtype `torch.int32` with 4 columns and one row per arc.
            In order, the columns of the tensor hold the src_state, the
            dest_state, the label and the score.

            Caution:
              A score is a float whose binary pattern has been
              **reinterpreted** as an integer so that it can live in a
              tensor of dtype `torch.int32`; it is not a numeric cast.

          aux_labels:
            Optional. A 1-D tensor giving one aux_label per arc, i.e.
            it has as many entries as there are arcs. It is converted
            to `torch.int32` before being attached to this Fsa.

          properties:
            Tensor properties if known. They are not checked, so only
            internal code should pass them; the intended user is
            Fsa.clone().
        '''
        if isinstance(arcs, torch.Tensor):
            ragged_arcs = _fsa_from_tensor(arcs)
        else:
            ragged_arcs = arcs
        assert isinstance(ragged_arcs, RaggedArc)

        # Write through the instance __dict__ so that __setattr__ is
        # bypassed for these internal fields.
        state = self.__dict__
        state['arcs'] = ragged_arcs
        state['_properties'] = properties

        # Three bookkeeping dicts are created below:
        #
        # `_tensor_attr`
        #     Attributes whose values are torch.Tensor. `shape[0]` of every
        #     value has to equal the number of arcs in the FSA. The standard
        #     entries are 'scores' and, for transducers, 'aux_labels'.
        #
        # `_non_tensor_attr`
        #     Attributes that are not tensors, e.g., :class:`SymbolTable`.
        #
        # `_cache`
        #     Tensors used for autograd. It is filled in automatically and
        #     users should NOT manipulate it. Its possible entries are:
        #
        #     `state_batches`
        #         from :func:`_k2._get_state_batches`
        #     `dest_states`
        #         from :func:`_k2._get_dest_states`
        #     `incoming_arcs`
        #         from :func:`_k2._get_incoming_arcs`
        #     `entering_arc_batches`
        #         from :func:`_k2._get_entering_arc_index_batches`
        #     `leaving_arc_batches`
        #         from :func:`_k2._get_leaving_arc_index_batches`
        #     `forward_scores_tropical`
        #         from :func:`_k2._get_forward_scores_float`
        #         with `log_semiring=False`
        #     `forward_scores_log`
        #         from :func:`_k2._get_forward_scores_float` or
        #         :func:`_get_forward_scores_double` with `log_semiring=True`
        #     `tot_scores_tropical`
        #         from :func:`_k2._get_tot_scores_float` or
        #         :func:`_k2._get_tot_scores_double`
        #         with `forward_scores_tropical`
        #     `tot_scores_log`
        #         from :func:`_k2._get_tot_scores_float` or
        #         :func:`_k2._get_tot_scores_double`
        #         with `forward_scores_log`
        #     `backward_scores_tropical`
        #         from :func:`_k2._get_backward_scores_float` or
        #         :func:`_k2._get_backward_scores_double`
        #         with `log_semiring=False`
        #     `backward_scores_log_semiring`
        #         from :func:`_k2._get_backward_scores_float` or
        #         :func:`_k2._get_backward_scores_double`
        #         with `log_semiring=True`
        #     `entering_arcs`
        #         from :func:`_k2._get_forward_scores_float` or
        #         :func:`_get_forward_scores_double` with `log_semiring=False`
        state['_tensor_attr'] = {}
        state['_non_tensor_attr'] = {}
        state['_cache'] = {}

        # The last column of the arc values carries the score bits;
        # `_as_float` is applied to it to obtain the `scores` attribute.
        score_bits = self.arcs.values()[:, -1]
        self._tensor_attr['scores'] = _as_float(score_bits)

        if aux_labels is not None:
            int32_aux_labels = aux_labels.to(torch.int32)
            self.aux_labels = int32_aux_labels

        # `properties` is a @property, so merely reading it runs its
        # getter, which sets up the properties and also checks that the
        # FSA is valid. The value itself is not needed here.
        _ = self.properties