def __init__(self,
             tensor: torch.Tensor,
             aux_labels: Optional[torch.Tensor] = None) -> None:
    '''Construct an Fsa from an arc tensor and optional aux_labels.

    This is handy when an Fsa is being loaded back from a file.

    Args:
      tensor:
        A `torch.int32` tensor with 4 columns, one row per arc:
        column 0 is src_state, column 1 dest_state, column 2 label and
        column 3 score. Caution: the scores are floats whose binary
        pattern has been **reinterpreted** as int32 (not converted).
      aux_labels:
        Optional 1-D tensor with one entry per arc (as many entries as
        `tensor` has rows). When given, it is converted to `torch.int32`
        and assigned to `self.aux_labels`.
    '''
    self._init_internal()

    fsa_arcs: RaggedArc = _fsa_from_tensor(tensor)
    self.arcs = fsa_arcs
    self._init_properties()

    # The last column holds the scores as int32 bit patterns; `_as_float`
    # turns them back into float scores.
    score_bits = self.arcs.values()[:, -1]
    self._tensor_attr['scores'] = _as_float(score_bits)

    if aux_labels is None:
        return
    self.aux_labels = aux_labels.to(torch.int32)
def from_tensor(cls,
                tensor: torch.Tensor,
                aux_labels: Optional[torch.Tensor] = None) -> 'Fsa':
    '''Build an Fsa from a tensor with optional aux_labels.

    It is useful when loading an Fsa from file. This alternate constructor
    creates the instance with ``cls.__new__`` and fills in its fields
    directly instead of going through ``Fsa.__init__``.

    Args:
      tensor:
        A torch tensor of dtype `torch.int32` with 4 columns. Each row
        represents an arc. Column 0 is the src_state, column 1 the
        dest_state, column 2 the label, and column 3 the score.
        Caution: Scores are floats and their binary pattern is
        **reinterpreted** as integers and saved in a tensor of dtype
        `torch.int32`.
      aux_labels:
        Optional. If not None, it associates an aux_label with every arc,
        so it has as many rows as `tensor`. It is a 1-D tensor; it is
        converted to `torch.int32` if it has a different dtype.

    Returns:
      A new instance of `cls`.
    '''
    # NOTE(review): the first parameter is `cls`, so this is presumably
    # decorated with @classmethod at its definition site -- confirm.
    ans = cls.__new__(cls)
    # `super(Fsa, ans)` skips Fsa.__init__ and runs only the initializer
    # of Fsa's parent class.
    super(Fsa, ans).__init__()
    ans._fsa = _fsa_from_tensor(tensor)
    # Normalize to the documented dtype. `.to()` returns the same tensor
    # when it is already int32, so correctly-typed inputs are unchanged.
    ans._aux_labels = (aux_labels.to(torch.int32)
                       if aux_labels is not None else None)
    return ans
def __init__(self,
             arcs: Union[torch.Tensor, RaggedArc],
             aux_labels: Optional[torch.Tensor] = None,
             properties=None) -> None:
    '''Build an Fsa from a tensor or a RaggedArc, with optional aux_labels.

    It is useful when loading an Fsa from file.

    Args:
      arcs:
        Either a `RaggedArc`, which is used as-is, or a torch tensor of
        dtype `torch.int32` with 4 columns, which is converted with
        `_fsa_from_tensor`. In the tensor form each row represents an
        arc: column 0 is the src_state, column 1 the dest_state, column 2
        the label, and column 3 the score.
        Caution: Scores are floats and their binary pattern is
        **reinterpreted** as integers and saved in a tensor of dtype
        `torch.int32`.
      aux_labels:
        Optional. If not None, it associates an aux_label with every arc,
        so it has as many rows as there are arcs. It is a 1-D tensor and
        is converted to `torch.int32` before being stored.
      properties:
        Tensor properties if known (should only be provided by internal
        code, as they are not checked; intended for use by Fsa.clone()).

    Raises:
      AssertionError: if `arcs` is neither a torch.Tensor nor a RaggedArc
        (only when assertions are enabled).
    '''
    if isinstance(arcs, torch.Tensor):
        arcs = _fsa_from_tensor(arcs)
    assert isinstance(arcs, RaggedArc), (
        f'Expected a torch.Tensor or RaggedArc, got {type(arcs)}')

    # Accessing self.__dict__ bypasses __setattr__.
    self.__dict__['arcs'] = arcs
    self.__dict__['_properties'] = properties

    # - `_tensor_attr`
    #     It saves attribute values of type torch.Tensor. `shape[0]` of
    #     attribute values have to be equal to the number of arcs
    #     in the FSA.  There are a couple of standard ones, 'aux_labels'
    #     (present for transducers), and 'scores'.
    #
    # - `_non_tensor_attr`
    #     It saves non-tensor attributes, e.g., :class:`SymbolTable`.
    #
    # - `_cache`
    #     It contains tensors for autograd. Users should NOT manipulate it.
    #     The dict is filled in automagically.
    #
    # The `_cache` dict contains the following attributes:
    #
    #  - `state_batches`:
    #           returned by :func:`_k2._get_state_batches`
    #  - `dest_states`:
    #           returned by :func:`_k2._get_dest_states`
    #  - `incoming_arcs`:
    #           returned by :func:`_k2._get_incoming_arcs`
    #  - `entering_arc_batches`:
    #           returned by :func:`_k2._get_entering_arc_index_batches`
    #  - `leaving_arc_batches`:
    #           returned by :func:`_k2._get_leaving_arc_index_batches`
    #  - `forward_scores_tropical`:
    #           returned by :func:`_k2._get_forward_scores_float`
    #           with `log_semiring=False`
    #  - `forward_scores_log`:
    #           returned by :func:`_k2._get_forward_scores_float` or
    #           :func:`_get_forward_scores_double` with `log_semiring=True`
    #  - `tot_scores_tropical`:
    #           returned by :func:`_k2._get_tot_scores_float` or
    #           :func:`_k2._get_tot_scores_double` with
    #           `forward_scores_tropical`.
    #  - `tot_scores_log`:
    #           returned by :func:`_k2._get_tot_scores_float` or
    #           :func:`_k2._get_tot_scores_double` with
    #           `forward_scores_log`.
    #  - `backward_scores_tropical`:
    #           returned by :func:`_k2._get_backward_scores_float` or
    #           :func:`_k2._get_backward_scores_double` with
    #           `log_semiring=False`
    #  - `backward_scores_log_semiring`:
    #           returned by :func:`_k2._get_backward_scores_float` or
    #           :func:`_k2._get_backward_scores_double` with
    #           `log_semiring=True`
    #  - `entering_arcs`:
    #           returned by :func:`_k2._get_forward_scores_float` or
    #           :func:`_get_forward_scores_double` with `log_semiring=False`
    for name in ['_tensor_attr', '_non_tensor_attr', '_cache']:
        self.__dict__[name] = dict()

    # The last column of the arc values holds the scores as int32 bit
    # patterns; `_as_float` turns them back into float scores.
    self._tensor_attr['scores'] = _as_float(self.arcs.values()[:, -1])
    if aux_labels is not None:
        self.aux_labels = aux_labels.to(torch.int32)

    # Access the properties field (it's a @property, i.e. it has a
    # getter) which sets up the properties and also checks that
    # the FSA is valid.
    _ = self.properties