def intersect(a_fsa: Fsa, b_fsa: Fsa) -> Fsa:
    '''Compute the intersection of two FSAs on CPU.

    Args:
      a_fsa:
        The first input FSA on CPU. It can be either a single FSA or an
        FsaVec.
      b_fsa:
        The second input FSA on CPU. It can be either a single FSA or an
        FsaVec.

    Caution:
      The two input FSAs MUST be arc sorted.

    Caution:
      The rules for assigning the attributes of the output Fsa are as follows:

      - (1) For attributes where only one source (a_fsa or b_fsa) has that
        attribute: Copy via arc_map, or use zero if arc_map has -1. This rule
        works for both floating point and integer attributes.

      - (2) For attributes where both sources (a_fsa and b_fsa) have that
        attribute: For floating point attributes: sum via arc_maps, or use
        zero if arc_map has -1. For integer attributes, it's not supported for
        now (the attributes will be discarded and will not be kept in the
        output FSA).

    Returns:
      The result of intersecting a_fsa and b_fsa. len(out_fsa.shape) is 2
      if and only if the two input FSAs are single FSAs;
      otherwise, len(out_fsa.shape) is 3.
    '''
    assert a_fsa.is_cpu()
    assert b_fsa.is_cpu()
    assert a_fsa.properties & fsa_properties.ARC_SORTED != 0
    assert b_fsa.properties & fsa_properties.ARC_SORTED != 0

    treat_epsilons_specially = True
    need_arc_map = True
    ragged_arc, a_arc_map, b_arc_map = _k2.intersect(
        a_fsa.arcs, a_fsa.properties, b_fsa.arcs, b_fsa.properties,
        treat_epsilons_specially, need_arc_map)

    out_fsa = Fsa(ragged_arc)

    # Names of attributes that exist on both inputs but are not float32.
    # Rule (2) says they must be dropped, so the second loop below has to
    # know about them; otherwise it would see them missing from `out_fsa`
    # and copy b_fsa's version.
    discarded = set()

    for name, a_value in a_fsa.named_tensor_attr():
        if hasattr(b_fsa, name):
            # Both a_fsa and b_fsa have this attribute.
            # We only support attributes with dtype `torch.float32`.
            # Other kinds of attributes are discarded.
            if a_value.dtype != torch.float32:
                discarded.add(name)
                continue
            b_value = getattr(b_fsa, name)
            assert b_value.dtype == torch.float32

            value = index_select(a_value, a_arc_map) \
                + index_select(b_value, b_arc_map)
            setattr(out_fsa, name, value)
        else:
            # only a_fsa has this attribute, copy it via arc_map
            value = index_attr(a_value, a_arc_map)
            setattr(out_fsa, name, value)

    # now copy tensor attributes that are in b_fsa but are not in a_fsa
    for name, b_value in b_fsa.named_tensor_attr():
        if name in discarded:
            continue
        if not hasattr(out_fsa, name):
            value = index_attr(b_value, b_arc_map)
            setattr(out_fsa, name, value)

    # Non-tensor attributes: a_fsa wins on a name clash.
    for name, a_value in a_fsa.named_non_tensor_attr():
        setattr(out_fsa, name, a_value)

    for name, b_value in b_fsa.named_non_tensor_attr():
        if not hasattr(out_fsa, name):
            setattr(out_fsa, name, b_value)

    return out_fsa
def intersect(
        a_fsa: Fsa,
        b_fsa: Fsa,
        treat_epsilons_specially: bool = True,
        ret_arc_maps: bool = False
) -> Union[Fsa, Tuple[Fsa, torch.Tensor, torch.Tensor]]:  # noqa
    '''Intersect two FSAs, optionally returning the arc maps as well.

    With `treat_epsilons_specially` set to True the function works only on
    CPU. With it set to False and both inputs on GPU, the function works on
    GPU, and in that case the inputs do not have to be arc sorted.

    Args:
      a_fsa:
        The first input FSA; either a single FSA or an FsaVec.
      b_fsa:
        The second input FSA; either a single FSA or an FsaVec.
      treat_epsilons_specially:
        If True, epsilons are treated as epsilon, meaning epsilon arcs can
        match with an implicit epsilon self-loop.
        If False, epsilons are treated as real, normal symbols (to have them
        treated as epsilons in this case you may have to add epsilon
        self-loops to whichever of the inputs is naturally epsilon-free).
      ret_arc_maps:
        If False, only the resulting Fsa is returned.
        If True, a tuple with three entries is returned:

          - the resulting Fsa

          - a_arc_map, a 1-D torch.Tensor with dtype torch.int32.
            a_arc_map[i] is the arc index in a_fsa that corresponds to the
            i-th arc in the resulting Fsa, or -1 if that arc has no
            corresponding arc in a_fsa.

          - b_arc_map, a 1-D torch.Tensor with dtype torch.int32.
            b_arc_map[i] is the arc index in b_fsa that corresponds to the
            i-th arc in the resulting Fsa, or -1 if that arc has no
            corresponding arc in b_fsa.

    Caution:
      The two input FSAs MUST be arc sorted if `treat_epsilons_specially`
      is True.

    Caution:
      The rules for assigning the attributes of the output Fsa are as follows:

      - (1) For attributes where only one source (a_fsa or b_fsa) has that
        attribute: Copy via arc_map, or use zero if arc_map has -1. This rule
        works for both floating point and integer attributes.

      - (2) For attributes where both sources (a_fsa and b_fsa) have that
        attribute: For floating point attributes: sum via arc_maps, or use
        zero if arc_map has -1. For integer attributes, it's not supported for
        now (the attributes will be discarded and will not be kept in the
        output FSA).

    Returns:
      If ret_arc_maps is False, the result of intersecting a_fsa and b_fsa.
      len(out_fsa.shape) is 2 if and only if the two input FSAs are single
      FSAs; otherwise, len(out_fsa.shape) is 3.
      If ret_arc_maps is True, additionally the two arc maps a_arc_map and
      b_arc_map.
    '''
    # As soon as one operand lives on CPU, both must carry the ARC_SORTED
    # property.
    if a_fsa.is_cpu() or b_fsa.is_cpu():
        for operand in (a_fsa, b_fsa):
            assert operand.properties & fsa_properties.ARC_SORTED != 0

    # The trailing True asks the backend to produce the arc maps; they are
    # always needed here to carry attributes over to the result.
    ragged_arc, a_arc_map, b_arc_map = _k2.intersect(
        a_fsa.arcs, a_fsa.properties,
        b_fsa.arcs, b_fsa.properties,
        treat_epsilons_specially, True)

    out_fsa = Fsa(ragged_arc)

    # Attribute handling is delegated to the shared helper, which fills
    # `out_fsa` from both inputs using the two arc maps.
    _propagate_aux_labels_binary_function(a_fsa, b_fsa, a_arc_map, b_arc_map,
                                          out_fsa)

    if not ret_arc_maps:
        return out_fsa
    return out_fsa, a_arc_map, b_arc_map
def compose(a_fsa: Fsa,
            b_fsa: Fsa,
            treat_epsilons_specially: bool = True,
            inner_labels: str = None) -> Fsa:
    '''Compute the composition of two FSAs.

    When `treat_epsilons_specially` is True, this function works only on CPU.
    When `treat_epsilons_specially` is False and both `a_fsa` and `b_fsa`
    are on GPU, then this function works on GPU; in this case, the two
    input FSAs do not need to be arc sorted.

    Note:
      `a_fsa.aux_labels` is required to be defined and it can be either
      a `torch.Tensor` or a ragged tensor of type `k2.RaggedInt`.
      If it is a ragged tensor, then it requires that a_fsa.requires_grad is
      False.

      For both FSAs, the `aux_labels` attribute is interpreted as output
      labels, (olabels), and the composition involves matching the olabels of
      a_fsa with the ilabels of b_fsa. This is implemented by intersecting the
      inverse of a_fsa (a_fsa_inv) with b_fsa, then replacing the ilabels of
      the result with the original ilabels on a_fsa which are now the
      aux_labels of a_fsa_inv. If `b_fsa.aux_labels` is not defined, `b_fsa`
      is treated as an acceptor (as in OpenFST), i.e. its olabels and ilabels
      are assumed to be the same.

    Refer to :func:`k2.intersect` for how we assign the attributes of the
    output FSA. As described there, a tensor attribute that exists on both
    inputs but is not `torch.float32` is dropped from the output.

    Args:
      a_fsa:
        The first input FSA. It can be either a single FSA or an FsaVec.
      b_fsa:
        The second input FSA. It can be either a single FSA or an FsaVec.
      treat_epsilons_specially:
        If True, epsilons will be treated as epsilon, meaning epsilon arcs can
        match with an implicit epsilon self-loop.
        If False, epsilons will be treated as real, normal symbols (to have
        them treated as epsilons in this case you may have to add epsilon
        self-loops to whichever of the inputs is naturally epsilon-free).
      inner_labels:
        If specified (and if a_fsa has `aux_labels`), the labels that we
        matched on, which would normally be discarded, will instead be copied
        to this attribute name.

    Caution:
      `b_fsa` has to be arc sorted if the function runs on CPU.

    Returns:
      The result of composing a_fsa and b_fsa. `len(out_fsa.shape)` is 2
      if and only if the two input FSAs are single FSAs;
      otherwise, `len(out_fsa.shape)` is 3.
    '''
    assert hasattr(a_fsa, 'aux_labels')

    if a_fsa.requires_grad:
        assert isinstance(a_fsa.aux_labels, torch.Tensor)
        a_fsa_inv = a_fsa.invert()
    else:
        # k2.invert() does not support autograd.
        # The current use case is for decoding, which does not need autograd.
        # We may extend it to support autograd if needed.
        a_fsa_inv = invert(a_fsa)

    if treat_epsilons_specially is True or a_fsa_inv.is_cpu():
        # the GPU version does not need to sort the input FSA
        a_fsa_inv = arc_sort(a_fsa_inv)

    if treat_epsilons_specially is True or b_fsa.is_cpu():
        # the GPU version does not need to sort the input FSA
        assert b_fsa.properties & fsa_properties.ARC_SORTED != 0

    need_arc_map = True
    ragged_arc, a_arc_map, b_arc_map = _k2.intersect(
        a_fsa_inv.arcs, a_fsa_inv.properties, b_fsa.arcs, b_fsa.properties,
        treat_epsilons_specially, need_arc_map)

    out_fsa = Fsa(ragged_arc)
    if inner_labels is not None:
        # out_fsa.`inner_labels` = out_fsa.labels.clone()
        # need a clone here since `Fsa.labels` is a reference
        setattr(out_fsa, inner_labels, out_fsa.labels.clone())

    if hasattr(b_fsa, 'aux_labels'):
        out_fsa.aux_labels = index(b_fsa.aux_labels, b_arc_map)
    else:
        # need a clone here since `Fsa.labels` is a reference
        out_fsa.aux_labels = out_fsa.labels.clone()

    if isinstance(a_fsa_inv.aux_labels, torch.Tensor):
        out_fsa.labels = index(a_fsa_inv.aux_labels, a_arc_map)
    else:
        assert isinstance(a_fsa_inv.aux_labels, k2.RaggedInt)
        # Refer to the following URLs for an example:
        # a_fsa: https://git.io/Jqbob
        # b_fsa: https://git.io/JqbKL
        # out_fsa: https://git.io/JqbK3
        #
        # Put the output labels on `labels`, the ragged input labels on
        # `aux_labels`, then invert to swap them back into place.
        out_fsa.labels = out_fsa.aux_labels
        out_fsa.aux_labels = index(a_fsa_inv.aux_labels, a_arc_map)
        out_fsa = invert(out_fsa)

    # Names of attributes that exist on both inputs but are not float32.
    # They must not be kept in the output, so the loop over b_fsa below has
    # to skip them explicitly; otherwise it would find them missing from
    # `out_fsa` and copy b_fsa's version.
    discarded = set()

    for name, a_value in a_fsa_inv.named_tensor_attr():
        if name in ('aux_labels', inner_labels):
            # already handled above
            continue
        if hasattr(b_fsa, name):
            # Both a_fsa and b_fsa have this attribute.
            # We only support attributes with dtype `torch.float32`.
            # Other kinds of attributes are discarded.
            if a_value.dtype != torch.float32:
                discarded.add(name)
                continue
            b_value = getattr(b_fsa, name)
            assert b_value.dtype == torch.float32

            # The following will actually overwrite `scores` with the same
            # value it had before; but this enables the autograd to work since
            # we do it using torch mechanisms.
            value = index_select(a_value, a_arc_map) + index_select(
                b_value, b_arc_map)
            setattr(out_fsa, name, value)
        else:
            # only a_fsa has this attribute, copy it via arc_map
            value = index(a_value, a_arc_map)
            setattr(out_fsa, name, value)

    # now copy tensor attributes that are in b_fsa but are not in a_fsa
    for name, b_value in b_fsa.named_tensor_attr():
        if name in ('aux_labels', inner_labels):
            continue
        if name in discarded:
            continue
        if not hasattr(out_fsa, name):
            value = index(b_value, b_arc_map)
            setattr(out_fsa, name, value)

    # a_fsa_inv is the inverse of a_fsa, so its `aux_symbols` describe the
    # output's input labels and are stored as `symbols`; its own `symbols`
    # describe the matched (inner) labels and are dropped.
    for name, a_value in a_fsa_inv.named_non_tensor_attr():
        if name == 'symbols':
            continue

        if name == 'aux_symbols':
            setattr(out_fsa, 'symbols', a_value)
        else:
            setattr(out_fsa, name, a_value)

    for name, b_value in b_fsa.named_non_tensor_attr():
        if name == 'symbols' and not hasattr(b_fsa, 'aux_labels'):
            # b_fsa is an acceptor: its symbols also name the output labels.
            setattr(out_fsa, 'aux_symbols', b_value)
        elif not hasattr(out_fsa, name):
            setattr(out_fsa, name, b_value)

    return out_fsa
def intersect(a_fsa: Fsa,
              b_fsa: Fsa,
              treat_epsilons_specially: bool = True) -> Fsa:
    '''Compute the intersection of two FSAs.

    When `treat_epsilons_specially` is True, this function works only on CPU.
    When `treat_epsilons_specially` is False and both `a_fsa` and `b_fsa`
    are on GPU, then this function works on GPU; in this case, the two
    input FSAs do not need to be arc sorted.

    Args:
      a_fsa:
        The first input FSA. It can be either a single FSA or an FsaVec.
      b_fsa:
        The second input FSA. It can be either a single FSA or an FsaVec.
      treat_epsilons_specially:
        If True, epsilons will be treated as epsilon, meaning epsilon arcs can
        match with an implicit epsilon self-loop.
        If False, epsilons will be treated as real, normal symbols (to have
        them treated as epsilons in this case you may have to add epsilon
        self-loops to whichever of the inputs is naturally epsilon-free).

    Caution:
      The two input FSAs MUST be arc sorted if `treat_epsilons_specially`
      is True.

    Caution:
      The rules for assigning the attributes of the output Fsa are as follows:

      - (1) For attributes where only one source (a_fsa or b_fsa) has that
        attribute: Copy via arc_map, or use zero if arc_map has -1. This rule
        works for both floating point and integer attributes.

      - (2) For attributes where both sources (a_fsa and b_fsa) have that
        attribute: For floating point attributes: sum via arc_maps, or use
        zero if arc_map has -1. For integer attributes, it's not supported for
        now (the attributes will be discarded and will not be kept in the
        output FSA).

    Returns:
      The result of intersecting a_fsa and b_fsa. len(out_fsa.shape) is 2
      if and only if the two input FSAs are single FSAs;
      otherwise, len(out_fsa.shape) is 3.
    '''
    if a_fsa.is_cpu() or b_fsa.is_cpu():
        assert a_fsa.properties & fsa_properties.ARC_SORTED != 0
        assert b_fsa.properties & fsa_properties.ARC_SORTED != 0

    need_arc_map = True
    ragged_arc, a_arc_map, b_arc_map = _k2.intersect(
        a_fsa.arcs, a_fsa.properties, b_fsa.arcs, b_fsa.properties,
        treat_epsilons_specially, need_arc_map)

    out_fsa = Fsa(ragged_arc)

    # Names of attributes that exist on both inputs but are not float32.
    # Rule (2) says they must be dropped, so the second loop below has to
    # know about them; otherwise it would see them missing from `out_fsa`
    # and copy b_fsa's version.
    discarded = set()

    for name, a_value in a_fsa.named_tensor_attr():
        if hasattr(b_fsa, name):
            # Both a_fsa and b_fsa have this attribute.
            # We only support attributes with dtype `torch.float32`.
            # Other kinds of attributes are discarded.
            if a_value.dtype != torch.float32:
                discarded.add(name)
                continue
            b_value = getattr(b_fsa, name)
            assert b_value.dtype == torch.float32

            value = index_select(a_value, a_arc_map) \
                + index_select(b_value, b_arc_map)
            setattr(out_fsa, name, value)
        else:
            # only a_fsa has this attribute, copy it via arc_map
            value = index(a_value, a_arc_map)
            setattr(out_fsa, name, value)

    # now copy tensor attributes that are in b_fsa but are not in a_fsa
    for name, b_value in b_fsa.named_tensor_attr():
        if name in discarded:
            continue
        if not hasattr(out_fsa, name):
            value = index(b_value, b_arc_map)
            setattr(out_fsa, name, value)

    # Non-tensor attributes: a_fsa wins on a name clash.
    for name, a_value in a_fsa.named_non_tensor_attr():
        setattr(out_fsa, name, a_value)

    for name, b_value in b_fsa.named_non_tensor_attr():
        if not hasattr(out_fsa, name):
            setattr(out_fsa, name, b_value)

    return out_fsa
def intersect(a_fsa: Fsa, b_fsa: Fsa) -> Fsa:
    '''Compute the intersection of two FSAs on CPU.

    Args:
      a_fsa:
        The first input FSA on CPU. It can be either a single FSA or an
        FsaVec.
      b_fsa:
        The second input FSA on CPU. It can be either a single FSA or an
        FsaVec.

    Caution:
      The two input FSAs MUST be arc sorted.

    Caution:
      The rules for assigning the attributes of the output Fsa are as follows:

      - (1) For attributes where only one source (a_fsa or b_fsa) has that
        attribute: Copy via arc_map, or use zero if arc_map has -1. This rule
        works for both floating point and integer attributes.

      - (2) For attributes where both sources (a_fsa and b_fsa) have that
        attribute: For floating point attributes: sum via arc_maps, or use
        zero if arc_map has -1. For integer attributes, it's not supported for
        now (the attributes will be discarded and will not be kept in the
        output FSA).

    Returns:
      The result of intersecting a_fsa and b_fsa.
    '''
    need_arc_map = True
    ragged_arc, a_arc_map, b_arc_map = _k2.intersect(a_fsa.arcs, b_fsa.arcs,
                                                     need_arc_map)

    # Some of entries in a_arc_map and b_arc_map may be -1.
    # The arc_maps are incremented so that every entry is non-negative.
    a_arc_map = a_arc_map.to(torch.int64) + 1
    b_arc_map = b_arc_map.to(torch.int64) + 1

    def _gather(value, arc_map):
        # The arc maps have been offset by 1, so prepend one all-zero row:
        # entries that were -1 now select that row and yield zeros. The
        # padding is derived from `value` itself so that its dtype, device
        # and trailing shape always match the tensor being padded.
        padding = value.new_zeros((1, *value.shape[1:]))
        padded = torch.cat((padding, value), dim=0)
        return padded.index_select(0, arc_map)

    out_fsa = Fsa.from_ragged_arc(ragged_arc)

    # Names of attributes that exist on both inputs but are not float32.
    # Rule (2) says they must be dropped, so the second loop below has to
    # know about them; otherwise it would see them missing from `out_fsa`
    # and copy b_fsa's version.
    discarded = set()

    for name, a_value in a_fsa.named_tensor_attr():
        if hasattr(b_fsa, name):
            # Both a_fsa and b_fsa have this attribute.
            # We only support attributes with dtype ``torch.float32``.
            # Other kinds of attributes are discarded.
            if a_value.dtype != torch.float32:
                discarded.add(name)
                continue
            b_value = getattr(b_fsa, name)
            assert b_value.dtype == torch.float32

            value = _gather(a_value, a_arc_map) + _gather(b_value, b_arc_map)
            setattr(out_fsa, name, value)
        else:
            # only a_fsa has this attribute, copy it via arc_map
            setattr(out_fsa, name, _gather(a_value, a_arc_map))

    # now copy tensor attributes that are in b_fsa but are not in a_fsa
    for name, b_value in b_fsa.named_tensor_attr():
        if name in discarded:
            continue
        if not hasattr(out_fsa, name):
            setattr(out_fsa, name, _gather(b_value, b_arc_map))

    # Non-tensor attributes: a_fsa wins on a name clash.
    for name, a_value in a_fsa.named_non_tensor_attr():
        setattr(out_fsa, name, a_value)

    for name, b_value in b_fsa.named_non_tensor_attr():
        if not hasattr(out_fsa, name):
            setattr(out_fsa, name, b_value)

    return out_fsa
def compose(a_fsa: Fsa,
            b_fsa: Fsa,
            treat_epsilons_specially: bool = True,
            inner_labels: str = None) -> Fsa:
    '''Compute the composition of two FSAs (currently on CPU).

    Note:
      If either input FSA has no `aux_labels`, this function simply returns
      :func:`k2.intersect` of the two inputs.

      The difference from :func:`k2.intersect` is when both FSAs have the
      `aux_labels` attribute set. On a_fsa these are interpreted as output
      labels (olabels), and the composition involves matching the olabels of
      a with the ilabels of b. This is implemented by intersecting the inverse
      of a_fsa (a_fsa_inv) with b_fsa, then replacing the ilabels of the
      result with the original ilabels on a_fsa which are now the aux_labels
      of a_fsa_inv.

    Args:
      a_fsa:
        The first input FSA on CPU. It can be either a single FSA or an
        FsaVec.
      b_fsa:
        The second input FSA on CPU. It can be either a single FSA or an
        FsaVec.
      treat_epsilons_specially:
        If True, epsilons will be treated as epsilon, meaning epsilon arcs can
        match with an implicit epsilon self-loop.
        If False, epsilons will be treated as real, normal symbols (to have
        them treated as epsilons in this case you may have to add epsilon
        self-loops to whichever of the inputs is naturally epsilon-free).
      inner_labels:
        If specified (and if a_fsa has `aux_labels`), the labels that we
        matched on, which would normally be discarded, will instead be copied
        to this attribute name.

    Caution:
      `b_fsa` has to be arc sorted.

    Caution:
      The rules for assigning the attributes of the output Fsa are as follows:

      - (1) For attributes where only one source (a_fsa or b_fsa) has that
        attribute: Copy via arc_map, or use zero if arc_map has -1. This rule
        works for both floating point and integer attributes.

      - (2) For attributes where both sources (a_fsa and b_fsa) have that
        attribute: For floating point attributes: sum via arc_maps, or use
        zero if arc_map has -1. For integer attributes, it's not supported for
        now (the attributes will be discarded and will not be kept in the
        output FSA).

    Returns:
      The result of composing a_fsa and b_fsa. `len(out_fsa.shape)` is 2
      if and only if the two input FSAs are single FSAs;
      otherwise, `len(out_fsa.shape)` is 3.
    '''
    assert a_fsa.is_cpu()
    assert b_fsa.is_cpu()

    # Without aux_labels on both sides this falls back to plain intersection.
    # NOTE(review): when only a_fsa has aux_labels this matches on a_fsa's
    # input labels rather than its output labels -- confirm that is intended.
    if not hasattr(a_fsa, 'aux_labels') or not hasattr(b_fsa, 'aux_labels'):
        return intersect(a_fsa, b_fsa, treat_epsilons_specially)

    assert isinstance(a_fsa.aux_labels, torch.Tensor)

    # Inverting a_fsa makes its olabels the matching labels of a_fsa_inv.
    a_fsa_inv = arc_sort(a_fsa.invert())
    assert b_fsa.properties & fsa_properties.ARC_SORTED != 0

    need_arc_map = True
    ragged_arc, a_arc_map, b_arc_map = _k2.intersect(
        a_fsa_inv.arcs, a_fsa_inv.properties, b_fsa.arcs, b_fsa.properties,
        treat_epsilons_specially, need_arc_map)

    out_fsa = Fsa(ragged_arc)
    if inner_labels is not None:
        # out_fsa.`inner_labels` = a copy of the labels we matched on.
        # Clone them: `out_fsa.labels` is overwritten right below, and the
        # saved inner labels must not change with it should `Fsa.labels`
        # share storage with the arcs.
        setattr(out_fsa, inner_labels, out_fsa.labels.clone())

    out_fsa.labels = index(a_fsa_inv.aux_labels, a_arc_map)
    out_fsa.aux_labels = index(b_fsa.aux_labels, b_arc_map)

    # Names of attributes that exist on both inputs but are not float32.
    # Rule (2) says they must be dropped, so the second loop below has to
    # know about them; otherwise it would see them missing from `out_fsa`
    # and copy b_fsa's version.
    discarded = set()

    for name, a_value in a_fsa_inv.named_tensor_attr():
        if hasattr(b_fsa, name):
            # Both a_fsa and b_fsa have this attribute.
            # We only support attributes with dtype `torch.float32`.
            # Other kinds of attributes are discarded.
            if a_value.dtype != torch.float32:
                discarded.add(name)
                continue
            b_value = getattr(b_fsa, name)
            assert b_value.dtype == torch.float32

            value = index_select(a_value, a_arc_map) + index_select(
                b_value, b_arc_map)
            setattr(out_fsa, name, value)
        else:
            # only a_fsa has this attribute, copy it via arc_map
            value = index(a_value, a_arc_map)
            setattr(out_fsa, name, value)

    # now copy tensor attributes that are in b_fsa but are not in a_fsa
    for name, b_value in b_fsa.named_tensor_attr():
        if name in discarded:
            continue
        if not hasattr(out_fsa, name):
            value = index(b_value, b_arc_map)
            setattr(out_fsa, name, value)

    # a_fsa_inv is the inverse of a_fsa, so its `aux_symbols` describe the
    # output's input labels and are stored as `symbols`; its own `symbols`
    # describe the matched (inner) labels and are dropped.
    for name, a_value in a_fsa_inv.named_non_tensor_attr():
        if name == 'symbols':
            continue

        if name == 'aux_symbols':
            setattr(out_fsa, 'symbols', a_value)
        else:
            setattr(out_fsa, name, a_value)

    for name, b_value in b_fsa.named_non_tensor_attr():
        if not hasattr(out_fsa, name):
            setattr(out_fsa, name, b_value)

    return out_fsa