Example #1
File: fsa_algo.py  Project: entn-at/k2
# Module-level imports this excerpt relies on (as in k2's fsa_algo.py; the
# relative imports assume the file lives inside the k2 Python package):
from typing import List, Optional, Tuple, Union

import torch

import k2
import _k2
from .fsa import Fsa
from .ops import index, index_select
def expand_ragged_attributes(
    fsas: Fsa,
    ret_arc_map: bool = False,
    ragged_attribute_names: Optional[List[str]] = None
) -> Union[Fsa, Tuple[Fsa, torch.Tensor]]:  # noqa
    '''
    Turn ragged labels attached to this FSA into linear (Tensor) labels,
    expanding arcs into sequences of arcs as necessary to achieve this.
    Supports autograd.  If `fsas` had no ragged attributes, returns `fsas`
    itself.

    Caution: this function will ensure that for final-arcs in the returned
    fsa, the corresponding labels for all ragged attributes are -1; it will
    add an extra arc at the end if necessary to ensure this, in case the
    original ragged attributes did not have -1 as their final element on
    final-arcs.  (Note: our intention is that -1's on final-arcs, like
    filler symbols, are removed when making attributes ragged; this is what
    fsa_from_unary_function_ragged() does if remove_filler==True, the
    default.)

    Args:
         fsas:  The source Fsa.
         ret_arc_map:  If true, will return a pair (new_fsas, arc_map)
              with `arc_map` a tensor of int32 that maps from arcs in the
              result to arcs in `fsas`, with -1's for newly created arcs.
              If false, just returns new_fsas.
         ragged_attribute_names:  If specified, just this list of ragged
              attributes will be expanded to linear tensor attributes, and
              the rest will stay ragged.
    '''
    if ragged_attribute_names is None:
        ragged_attribute_tensors = []
        ragged_attribute_names = []
        for name, value in fsas.named_tensor_attr(include_scores=False):
            if isinstance(value, k2.RaggedInt):
                ragged_attribute_tensors.append(value)
                ragged_attribute_names.append(name)
    else:
        ragged_attribute_tensors = [
            getattr(fsas, name) for name in ragged_attribute_names
        ]
        for t in ragged_attribute_tensors:
            assert isinstance(t, k2.RaggedInt)

    if len(ragged_attribute_tensors) == 0:
        if ret_arc_map:
            arc_map = torch.arange(fsas.num_arcs,
                                   dtype=torch.int32,
                                   device=fsas.device)
            return (fsas, arc_map)
        else:
            return fsas

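    # _k2.expand_arcs expands each arc whose ragged rows have more than one
    # element into a chain of arcs.  It returns the new arcs, one linear
    # label tensor per ragged attribute (aligned with the new arcs), and an
    # arc_map from each new arc back to its source arc (-1 for arcs it
    # created).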
    (dest_arcs, dest_labels,
     arc_map) = _k2.expand_arcs(fsas.arcs, ragged_attribute_tensors)

    # The rest of this function is a modified version of
    # `fsa_from_unary_function_tensor()`.
    dest = Fsa(dest_arcs)

    # Handle the non-ragged attributes, and ragged attributes that
    # we're not linearizing.
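    # For plain tensor attributes, arcs created by the expansion (where
    # arc_map is -1) receive the attribute's filler value, which is 0
    # unless a different filler was registered for that attribute.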
    for name, value in fsas.named_tensor_attr(include_scores=False):
        if isinstance(value, torch.Tensor):
            filler = float(fsas.get_filler(name))
            setattr(dest, name,
                    index_select(value, arc_map, default_value=filler))
        elif name not in ragged_attribute_names:
            setattr(dest, name, index(value, arc_map))

    # Handle the attributes that were ragged but are now linear
    for name, value in zip(ragged_attribute_names, dest_labels):
        setattr(dest, name, value)

    # Copy non-tensor attributes
    for name, value in fsas.named_non_tensor_attr():
        setattr(dest, name, value)

    # make sure autograd works on the scores
    k2.autograd_utils.phantom_index_select_scores(dest, fsas.scores, arc_map)

    # Make sure -1's are only on final-arcs, and never on non-final arcs.
    if hasattr(dest, 'aux_labels'):
        _k2.fix_final_labels(dest.arcs, dest.aux_labels)

    if ret_arc_map:
        return dest, arc_map
    else:
        return dest
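
A minimal usage sketch of expand_ragged_attributes (not from the source: the
FSA, the attribute name `phones`, and its values are made up for
illustration; the old k2.RaggedInt API used elsewhere in this version of k2
is assumed):

import k2

# A two-arc FSA in k2's text format; the arc labelled -1 enters the
# final state.
s = '''
0 1 2 10
1 2 -1 30
2
'''
fsa = k2.Fsa.from_str(s)
# Hypothetical ragged attribute: arc 0 carries [1 2], arc 1 carries [3].
fsa.phones = k2.RaggedInt('[[1 2] [3]]')

dest, arc_map = k2.expand_ragged_attributes(fsa, ret_arc_map=True)
# dest.phones is now a linear torch.Tensor with one element per arc;
# arc_map contains -1 for every arc created by the expansion.
print(dest.phones, arc_map)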
Example #2
# This test excerpt assumes the enclosing test class and module-level
# imports: torch, k2 and _k2.
    def test_final(self):
        devices = [torch.device('cpu')]
        if torch.cuda.is_available() and k2.with_cuda:
            devices.append(torch.device('cuda', 0))

        for device in devices:
            for need_map in [True, False]:
                s = '''
                0 1 2 10
                0 1 1 20
                1 2 -1 30
                2
                '''
                src = k2.Fsa.from_str(s).to(device).requires_grad_(True)
                src.float_attr = torch.tensor([0.1, 0.2, 0.3],
                                              dtype=torch.float32,
                                              requires_grad=True,
                                              device=device)
                src.int_attr = torch.tensor([1, 2, 3],
                                            dtype=torch.int32,
                                            device=device)
                src.ragged_attr = k2.RaggedInt('[[1 2 3] [5 6] [1]]').to(
                    device)

                src.attr1 = 'src'
                src.attr2 = 'fsa'

                if need_map:
                    dest, arc_map = k2.expand_ragged_attributes(
                        src, ret_arc_map=True)
                else:
                    dest = k2.expand_ragged_attributes(src)
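                # Arc 0 (ragged row [1 2 3]) expands into a chain of three
                # arcs and arc 1 ([5 6]) into two; an extra final arc is
                # appended so that the ragged attribute ends in -1, giving
                # 7 arcs in total.  Linear attributes keep their value on
                # the first arc of each chain and take the filler (0 here)
                # on all inserted arcs.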

                assert torch.allclose(
                    dest.float_attr,
                    torch.tensor([0.1, 0.2, 0.0, 0.0, 0.0, 0.3, 0.0],
                                 dtype=torch.float32,
                                 device=device))
                assert torch.all(
                    torch.eq(
                        dest.scores,
                        torch.tensor([10, 20, 0, 0, 0, 30, 0],
                                     dtype=torch.float32,
                                     device=device)))
                assert torch.all(
                    torch.eq(
                        dest.int_attr,
                        torch.tensor([1, 2, 0, 0, 0, 3, 0],
                                     dtype=torch.int32,
                                     device=device)))
                _k2.fix_final_labels(dest.arcs, dest.int_attr)
                assert torch.all(
                    torch.eq(
                        dest.int_attr,
                        torch.tensor([1, 2, 0, 0, 0, 3, -1],
                                     dtype=torch.int32,
                                     device=device)))

                assert torch.all(
                    torch.eq(
                        dest.ragged_attr,
                        torch.tensor([1, 5, 2, 3, 6, 1, -1],
                                     dtype=torch.int32,
                                     device=device)))

                # non-tensor attributes...
                assert dest.attr1 == src.attr1
                assert dest.attr2 == src.attr2

                # now for autograd
                scale = torch.tensor([10, 20, 10, 10, 10, 30, 10],
                                     device=device)
                (dest.float_attr * scale).sum().backward()
                (dest.scores * scale).sum().backward()
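
                # Only result arcs 0, 1 and 5 map back to source arcs, so
                # each source arc's gradient picks up the matching entry of
                # `scale`: 10, 20 and 30.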

                expected_grad = torch.tensor([10, 20, 30],
                                             dtype=torch.float32,
                                             device=device)

                assert torch.all(torch.eq(src.float_attr.grad, expected_grad))
                assert torch.all(torch.eq(src.scores.grad, expected_grad))