def index_tensor(
        src: torch.Tensor, indexes: Union[torch.Tensor, _k2.RaggedInt]
) -> Union[torch.Tensor, _k2.RaggedInt]:  # noqa
    '''Index a 1-D tensor with either a 1-D tensor or a ragged tensor.

    Args:
      src:
        Source 1-D tensor to index. Its dtype must be `torch.int32` or
        `torch.float32`.
      indexes:
        The indexes to look up; they must satisfy
        -1 <= indexes.values()[i] < src.numel().

        - If it is a tensor, its values are interpreted as indexes into
          `src`; wherever an index is -1 the corresponding output is 0.
        - If it is a ragged tensor, `indexes.values()` are interpreted as
          indexes into `src`. With `src.dtype == torch.int32` the result is
          a `_k2.RaggedInt`; with `src.dtype == torch.float32` an extra
          sum-per-sublist step is applied and a 1-D `torch.Tensor` is
          returned.

    Returns:
      A tensor or a ragged tensor, as described for `indexes` above.
    '''
    # Plain tensor indexes: delegate to `index_select`.
    if isinstance(indexes, torch.Tensor):
        return index_select(src, indexes)

    # From here on `indexes` has to be ragged.
    assert isinstance(indexes, k2.RaggedInt)

    if src.dtype == torch.int32:
        return _k2.index(src, indexes)

    # The only other supported dtype is float32, which goes through the
    # index-then-sum path.
    assert src.dtype == torch.float32
    return index_and_sum(src, indexes)
def index_ragged(
        src: _k2.RaggedInt,
        indexes: Union[torch.Tensor, _k2.RaggedInt]) -> _k2.RaggedInt:  # noqa
    '''Index a ragged tensor with a 1-D tensor or with a ragged tensor.

    Args:
      src:
        Source ragged tensor to index; it must have num_axes() == 2.
      indexes:
        Either a tensor or a ragged tensor.

        - Tensor: must be 1-D with `indexes.dtype == torch.int32`. Its
          values are interpreted as indexes into axis 0 of `src`, i.e.
          -1 <= indexes[i] < src.dim0(). Where indexes[i] is -1, the i-th
          entry of the answer is empty.
        - Ragged tensor: must have num_axes() == 2. `indexes.values` are
          interpreted as indexes into axis 0 of `src`, i.e.
          0 <= indexes.values[i] < src.dim0().

    Returns:
      The indexed ragged tensor, with ans.num_axes() == 2.

      - For a 1-D tensor `indexes`, ans.dim0() == indexes.numel().
      - For a ragged `indexes`, ans.dim0() == indexes.dim0().
    '''
    # Ragged indexes are handled directly by the `_k2` binding.
    if not isinstance(indexes, torch.Tensor):
        return _k2.index(src, indexes)

    # Tensor indexes: `ragged_index` yields a pair; only the first element
    # (the indexed ragged tensor) is returned to the caller.
    selected, _ = ragged_index(src, indexes)
    return selected
def index(
        src: _k2.RaggedArc,
        indexes: torch.Tensor,
        need_value_indexes: bool = True
) -> Tuple[_k2.RaggedArc, Optional[torch.Tensor]]:  # noqa
    '''Return `src[indexes]`, indexing along axis 0 of a ragged tensor.

    Caution:
      `indexes` is a 1-D tensor with `indexes.dtype == torch.int32`.

    Args:
      src:
        Source ragged tensor to index.
      indexes:
        Array of indexes into axis 0 of `src`, i.e. each entry satisfies
        0 <= indexes[i] < src.dim0().
      need_value_indexes:
        If true, also return a `torch.Tensor` holding the indexes into
        `src.values()` that make up `ans.values()`, so that
        `ans.values() = src.values()[value_indexes]`.

    Returns:
      A tuple of two items:

      - `ans`, of type `_k2.RaggedArc`.
      - `None` when `need_value_indexes` is False; otherwise a 1-D
        `torch.Tensor` of dtype `torch.int32` with the indexes into
        `src.values()` that `ans.values()` has.
    '''
    # Unpack and re-pack so the caller always receives a 2-tuple.
    indexed, picked = _k2.index(src=src,
                                indexes=indexes,
                                need_value_indexes=need_value_indexes)
    return indexed, picked
def index(
        src: Union[_k2.RaggedArc, _k2.RaggedInt, _k2.RaggedShape],
        indexes: torch.Tensor,
        need_value_indexes: bool = True,
        axis: int = 0
) -> Tuple[Union[_k2.RaggedArc, _k2.RaggedInt, _k2.RaggedShape],  # noqa
           Optional[torch.Tensor]]:  # noqa
    '''Return `src[indexes]`, indexing along a chosen axis of `src`.

    Caution:
      `indexes` is a 1-D tensor with `indexes.dtype == torch.int32`.

    Args:
      src:
        Source ragged tensor or ragged shape to index.
      indexes:
        Array of indexes into axis `axis` of `src`, i.e. each entry
        satisfies 0 <= indexes[i] < src.tot_size(axis). When `axis` is 0,
        -1 is also a valid entry.
      need_value_indexes:
        If true, also return a `torch.Tensor` holding the indexes into
        `src.values()` that make up `ans.values()`, so that
        `ans.values() = src.values()[value_indexes]`.
      axis:
        The axis to be indexed. Must satisfy 0 <= axis < src.num_axes().

    Returns:
      A tuple of two items:

      - `ans`, of type `_k2.RaggedArc` or `_k2.RaggedInt` (the same type
        as `src`).
      - `None` when `need_value_indexes` is False; otherwise a 1-D
        `torch.Tensor` of dtype `torch.int32` with the indexes into
        `src.values()` that `ans.values()` has.
    '''
    # Unpack and re-pack so the caller always receives a 2-tuple.
    indexed, picked = _k2.index(src=src,
                                axis=axis,
                                indexes=indexes,
                                need_value_indexes=need_value_indexes)
    return indexed, picked
def index_tensor(
        src: torch.Tensor, indexes: Union[torch.Tensor, _k2.RaggedInt]
) -> Union[torch.Tensor, _k2.RaggedInt]:  # noqa
    '''Index a 1-D tensor with either a 1-D tensor or a ragged tensor.

    Args:
      src:
        Source 1-D tensor to index; it must have
        `src.dtype == torch.int32`.
      indexes:
        - If it is a ragged tensor, `indexes.values` are interpreted as
          indexes into `src`, i.e. 0 <= indexes.values[i] < src.numel().
        - If it is a tensor, its values are interpreted as indexes into
          `src`.

    Returns:
      A tensor or a ragged tensor, depending on the type of `indexes`.
    '''
    if not isinstance(indexes, torch.Tensor):
        # TODO(haowen): it does not autograd now.
        return _k2.index(src, indexes)

    # Plain tensor indexes: delegate to `index_select`.
    return index_select(src, indexes)
def index_fsa(src: Fsa, indexes: torch.Tensor) -> Fsa:
    '''Pick a list of FSAs out of `src` using a 1-D tensor of ids.

    Args:
      src:
        An FsaVec.
      indexes:
        A 1-D `torch.Tensor` of dtype `torch.int32` holding the ids of the
        FSAs to select.

    Returns:
      An FsaVec containing only the FSAs named by `indexes`. Tensor
      attributes of `src` are re-indexed to match the selection;
      non-tensor attributes are copied over unchanged.
    '''
    # TODO: export it to k2
    # Index the arcs along axis 0 and keep the indexes into the values of
    # `src.arcs` that were selected; they drive the attribute re-indexing.
    picked_arcs, value_idx = _k2.index(src.arcs,
                                       axis=0,
                                       indexes=indexes,
                                       need_value_indexes=True)
    selected = Fsa(picked_arcs)

    for attr_name, attr_value in src.named_tensor_attr():
        if isinstance(attr_value, torch.Tensor):
            new_value = k2.ops.index_select(attr_value, value_idx)
        else:
            # Anything that is not a plain tensor must be an int32 ragged
            # tensor; index it along axis 0 with the same value indexes.
            assert isinstance(attr_value, k2.RaggedTensor)
            assert attr_value.dtype == torch.int32
            new_value, _ = attr_value.index(value_idx,
                                            axis=0,
                                            need_value_indexes=False)
        setattr(selected, attr_name, new_value)

    # Non-tensor attributes are carried over as they are.
    for attr_name, attr_value in src.named_non_tensor_attr():
        setattr(selected, attr_name, attr_value)

    return selected