Esempio n. 1
0
File: utils.py Progetto: aarora8/k2
def create_fsa_vec(fsas):
    '''Create an FsaVec from a list of FSAs

    We use the following rules to set the attributes of the output FsaVec:

    - For tensor attributes, we assume that all input FSAs have the same
      attribute name and the values are concatenated.

    - For non-tensor attributes, if any two of the input FSAs have the same
      attribute name, then we assume that their attribute values are equal and
      the output FSA will inherit the attribute.

    Args:
      fsas:
        A list of `Fsa`. Each element must be a single FSA.

    Returns:
      An instance of :class:`Fsa` that represents a FsaVec.

    Raises:
      AssertionError: if an input is not a single FSA (its shape does not
        have exactly 2 axes), if a tensor attribute is neither a
        `torch.Tensor` nor a `k2.RaggedTensor`, or if two inputs carry
        different values for the same non-tensor attribute.
    '''
    ragged_arc_list = []
    for fsa in fsas:
        # Each input must be a single FSA (2-axis shape), per the docstring.
        assert len(fsa.shape) == 2, \
            f'Expected a single FSA, got shape {fsa.shape}'
        ragged_arc_list.append(fsa.arcs)

    ragged_arcs = _k2.create_fsa_vec(ragged_arc_list)
    fsa_vec = Fsa(ragged_arcs)

    # Gather tensor attribute names from *every* input FSA. The clause order
    # matters: `for fsa in fsas` must come first, otherwise the inner
    # iterable is evaluated against a stale `fsa` from the loop above and
    # only the last FSA's attributes are ever seen.
    tensor_attr_names = {
        name for fsa in fsas for name, _ in fsa.named_tensor_attr()
    }
    for name in tensor_attr_names:
        # Every input is expected to have this attribute (see docstring);
        # getattr raises AttributeError otherwise.
        values = [getattr(fsa, name) for fsa in fsas]
        if isinstance(values[0], torch.Tensor):
            value = torch.cat(values)
        else:
            assert isinstance(values[0], k2.RaggedTensor)
            value = k2.ragged.cat(values, axis=0)
        setattr(fsa_vec, name, value)

    non_tensor_attr_names = {
        name for fsa in fsas for name, _ in fsa.named_non_tensor_attr()
    }
    # 'properties' is never copied onto the FsaVec.
    # NOTE(review): presumably the FsaVec derives its own properties — confirm.
    non_tensor_attr_names.discard('properties')

    for name in non_tensor_attr_names:
        for fsa in fsas:
            value = getattr(fsa, name, None)
            if value is None:
                continue
            if hasattr(fsa_vec, name):
                # Inputs sharing a non-tensor attribute must agree on it.
                assert getattr(fsa_vec, name) == value, \
                    f'Conflicting values for non-tensor attribute {name!r}'
            else:
                setattr(fsa_vec, name, value)
    return fsa_vec
Esempio n. 2
0
    def test_compose(self):
        # Compose two small FSAs, connect the result, and check the total
        # score (log semiring) and the gradients w.r.t. both inputs' scores.
        s = '''
            0 1 11 1 1.0
            0 2 12 2 2.5
            1 3 -1 -1 0
            2 3 -1 -1 2.5
            3
        '''
        a_fsa = k2.Fsa.from_str(s).requires_grad_(True)

        s = '''
            0 1 1 1 1.0
            0 2 2 3 3.0
            1 2 3 2 2.5
            2 3 -1 -1 2.0
            3
        '''
        b_fsa = k2.Fsa.from_str(s).requires_grad_(True)

        ans = k2.compose(a_fsa, b_fsa, inner_labels='inner')
        ans = k2.connect(ans)

        # Convert a single FSA to a FsaVec.
        # It will retain `requires_grad_` of `ans`.
        # NOTE(review): writing to `__dict__` bypasses Fsa's attribute
        # setter on purpose; only `arcs` is replaced.
        ans.__dict__['arcs'] = _k2.create_fsa_vec([ans.arcs])

        scores = k2.get_tot_scores(ans,
                                   log_semiring=True,
                                   use_double_scores=False)
        # The reference values for `scores`, `a_fsa.grad` and `b_fsa.grad`
        # are computed using GTN.
        # See https://bit.ly/3heLAJq
        # Use a tolerance instead of exact float equality: the scores are
        # single-precision (use_double_scores=False).
        assert abs(scores.item() - 10) < 1e-5
        scores.backward()
        assert torch.allclose(a_fsa.grad, torch.tensor([0., 1., 0., 1.]))
        assert torch.allclose(b_fsa.grad, torch.tensor([0., 1., 0., 1.]))