def intersect_device(
        a_fsas: Fsa,
        b_fsas: Fsa,
        b_to_a_map: torch.Tensor,
        sorted_match_a: bool = False,
        ret_arc_maps: bool = False
) -> Union[Fsa, Tuple[Fsa, torch.Tensor, torch.Tensor]]:  # noqa
    '''Intersect two FsaVecs, treating epsilons as ordinary symbols.

    Runs on both CPU and GPU, but is designed for GPU; on CPU it is very
    slow, which is what the `_device` suffix signals. :func:`k2.intersect`
    is the more general entry point; it dispatches to the same underlying
    IntersectDevice() code when the inputs are on GPU and `a_fsas` is
    arc-sorted.

    Caution:
      Epsilons get no special treatment; they match like any other symbol.

    Hint:
      Neither input has to be arc-sorted.

    See :func:`k2.intersect` for how attributes of the output FsaVec are
    assigned.

    Args:
      a_fsas:
        An FsaVec with 3 axes, i.e. `len(a_fsas.shape) == 3`.
      b_fsas:
        An FsaVec with 3 axes, on the same device as `a_fsas`.
      b_to_a_map:
        1-D torch.Tensor of dtype torch.int32, on the same device as
        `a_fsas`. Entry i gives the index of the FSA in `a_fsas` that
        the i-th FSA of `b_fsas` is intersected with (for example an
        identity map, all zeros, or any user-chosen mapping). Requires:

          - `b_to_a_map.shape[0] == b_fsas.shape[0]`
          - `0 <= b_to_a_map[i] < a_fsas.shape[0]`
      sorted_match_a:
        Passed through unchanged to `_k2.intersect_device`.
        NOTE(review): presumably this selects a matching strategy that
        relies on `a_fsas` being arc-sorted; confirm against the C++
        IntersectDevice() before enabling it for unsorted input.
      ret_arc_maps:
        If True, the arc maps are returned as well (see Returns).

    Returns:
      When `ret_arc_maps` is False, the intersected FsaVec `ans`, with
      `ans.shape == b_fsas.shape`.

      When `ret_arc_maps` is True, the tuple `(ans, a_arc_map, b_arc_map)`:

        - a_arc_map: 1-D torch.Tensor of dtype torch.int32;
          `a_arc_map[i]` is the index of the arc in `a_fsas` that
          produced arc i of `ans`, or -1 if there is none.
        - b_arc_map: the same, but indexing arcs of `b_fsas`.
    '''
    # The arc maps are always requested (the positional `True` below is
    # `need_arc_map`): attribute propagation consumes them even when the
    # caller does not ask for them back.
    ragged_arc, a_arc_map, b_arc_map = _k2.intersect_device(
        a_fsas.arcs, a_fsas.properties,
        b_fsas.arcs, b_fsas.properties,
        b_to_a_map, True, sorted_match_a)

    result = Fsa(ragged_arc)
    _propagate_aux_labels_binary_function(a_fsas, b_fsas, a_arc_map,
                                          b_arc_map, result)

    return (result, a_arc_map, b_arc_map) if ret_arc_maps else result
def intersect_device(
        a_fsas: Fsa,
        b_fsas: Fsa,
        b_to_a_map: torch.Tensor,
        sorted_match_a: bool = False,
        ret_arc_maps: bool = False
) -> Union[Fsa, Tuple[Fsa, torch.Tensor, torch.Tensor]]:  # noqa
    '''Compute the intersection of two FsaVecs treating epsilons as real,
    normal symbols.

    This function supports both CPU and GPU, but it is very slow on CPU;
    that is why its name ends with `_device`. It is intended for GPU.
    See :func:`k2.intersect` for intersecting two FSAs on CPU.

    NOTE: this `def` comes after an earlier `intersect_device` in the same
    module, so this is the definition bound once the module has loaded. It
    therefore accepts that definition's extended signature
    (`sorted_match_a`, `ret_arc_maps`), so callers of either form work.

    Caution:
      Epsilons are treated as real, normal symbols.

    Hint:
      The two inputs do not need to be arc-sorted.

    Args:
      a_fsas:
        An FsaVec (must have 3 axes, i.e., `len(a_fsas.shape) == 3`).
      b_fsas:
        An FsaVec (must have 3 axes) on the same device as `a_fsas`.
      b_to_a_map:
        A 1-D torch.Tensor with dtype torch.int32 on the same device as
        `a_fsas`. Maps each FSA-id in `b_fsas` to the FSA-id in `a_fsas`
        that it is intersected with, e.g. an identity map, all-to-zero,
        or something the user chooses. Requires:

          - `b_to_a_map.shape[0] == b_fsas.shape[0]`
          - `0 <= b_to_a_map[i] < a_fsas.shape[0]`
      sorted_match_a:
        Forwarded unchanged to `_k2.intersect_device`. Defaults to False.
        NOTE(review): presumably this selects a matching strategy that
        relies on `a_fsas` being arc-sorted; confirm before enabling.
      ret_arc_maps:
        If True, also return the arc maps (see Returns). Defaults to
        False, which keeps the original single-value return.

    Returns:
      If `ret_arc_maps` is False: the intersected FsaVec `ans`, which will
      satisfy `ans.shape == b_fsas.shape`.

      If `ret_arc_maps` is True: the tuple `(ans, a_arc_map, b_arc_map)`,
      where `a_arc_map` / `b_arc_map` are the arc maps returned by
      `_k2.intersect_device`, relating arcs of `ans` to arcs of
      `a_fsas` / `b_fsas` respectively.

    Attribute propagation:
      - A tensor attribute on both inputs with dtype torch.float32 is
        summed after indexing each side through its arc map.
      - A tensor attribute on both inputs with any other dtype is not
        combined; it is copied from `b_fsas` by the b-only pass below.
      - A tensor attribute on only one input is copied through that
        input's arc map.
      - Non-tensor attributes are copied, with `a_fsas` taking precedence.
    '''
    # Arc maps are always needed here for attribute propagation.
    need_arc_map = True
    ragged_arc, a_arc_map, b_arc_map = _k2.intersect_device(
        a_fsas.arcs, a_fsas.properties, b_fsas.arcs, b_fsas.properties,
        b_to_a_map, need_arc_map, sorted_match_a)
    out_fsas = Fsa(ragged_arc)

    for name, a_value in a_fsas.named_tensor_attr():
        if hasattr(b_fsas, name):
            # Both a_fsas and b_fsas have this attribute. Only float32
            # attributes are combined (summed) here.
            if a_value.dtype != torch.float32:
                # Not combined. out_fsas does not receive it here, so the
                # b_fsas pass below (which only checks
                # `hasattr(out_fsas, name)`) copies b's value instead.
                continue
            b_value = getattr(b_fsas, name)
            assert b_value.dtype == torch.float32
            value = index_select(a_value, a_arc_map) \
                + index_select(b_value, b_arc_map)
            setattr(out_fsas, name, value)
        else:
            # Only a_fsas has this attribute; copy it via the arc map.
            value = index(a_value, a_arc_map)
            setattr(out_fsas, name, value)

    # Copy tensor attributes that out_fsas does not have yet from b_fsas.
    for name, b_value in b_fsas.named_tensor_attr():
        if not hasattr(out_fsas, name):
            value = index(b_value, b_arc_map)
            setattr(out_fsas, name, value)

    # Non-tensor attributes: a_fsas first, so its values take precedence.
    for name, a_value in a_fsas.named_non_tensor_attr():
        setattr(out_fsas, name, a_value)

    for name, b_value in b_fsas.named_non_tensor_attr():
        if not hasattr(out_fsas, name):
            setattr(out_fsas, name, b_value)

    if ret_arc_maps:
        return out_fsas, a_arc_map, b_arc_map
    return out_fsas