Code example #1
    def test_index_fsa(self):
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda', 0))

        for device in devices:
            s1 = '''
                0 1 1 0.1
                1 2 -1 0.2
                2
            '''
            s2 = '''
                0 1 -1 1.0
                1
            '''
            fsa1 = k2.Fsa.from_str(s1)
            fsa1.tensor_attr = torch.tensor([10, 20], dtype=torch.int32)
            fsa1.ragged_attr = k2.ragged.create_ragged2([[11, 12],
                                                         [21, 22, 23]])

            fsa2 = k2.Fsa.from_str(s2)
            fsa2.tensor_attr = torch.tensor([100], dtype=torch.int32)
            fsa2.ragged_attr = k2.ragged.create_ragged2([[111]])

            fsa1 = fsa1.to(device)
            fsa2 = fsa2.to(device)

            fsa_vec = k2.create_fsa_vec([fsa1, fsa2])

            single1 = k2.index_fsa(
                fsa_vec, torch.tensor([0], dtype=torch.int32, device=device))
            assert torch.all(torch.eq(fsa1.tensor_attr, single1.tensor_attr))
            assert str(single1.ragged_attr) == str(fsa1.ragged_attr)
            assert single1.device == device

            single2 = k2.index_fsa(
                fsa_vec, torch.tensor([1], dtype=torch.int32, device=device))
            assert torch.all(torch.eq(fsa2.tensor_attr, single2.tensor_attr))
            assert str(single2.ragged_attr) == str(fsa2.ragged_attr)
            assert single2.device == device

            multiples = k2.index_fsa(
                fsa_vec,
                torch.tensor([0, 1, 0, 1, 1], dtype=torch.int32,
                             device=device))
            assert multiples.shape == (5, None, None)
            assert torch.all(
                torch.eq(
                    multiples.tensor_attr,
                    torch.cat(
                        (fsa1.tensor_attr, fsa2.tensor_attr, fsa1.tensor_attr,
                         fsa2.tensor_attr, fsa2.tensor_attr))))
            assert str(multiples.ragged_attr) == str(
                k2.ragged.append([
                    fsa1.ragged_attr, fsa2.ragged_attr, fsa1.ragged_attr,
                    fsa2.ragged_attr, fsa2.ragged_attr
                ],
                                 axis=0))  # noqa
            assert multiples.device == device
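
The test above exercises attribute propagation in detail; the core pattern is simply an FsaVec plus an int32 index tensor on the same device. A minimal CPU-only sketch (not part of the original test, using only the k2 calls already shown above) in which the indexes both select and reorder FSAs, with tensor attributes following along:

import torch
import k2

fsa_a = k2.Fsa.from_str('''
    0 1 -1 0.1
    1
''')
fsa_a.tensor_attr = torch.tensor([1], dtype=torch.int32)

fsa_b = k2.Fsa.from_str('''
    0 1 -1 0.2
    1
''')
fsa_b.tensor_attr = torch.tensor([2], dtype=torch.int32)

fsa_vec = k2.create_fsa_vec([fsa_a, fsa_b])
# Select FSA 1 first, then FSA 0; the attribute rows are reordered accordingly.
reordered = k2.index_fsa(fsa_vec, torch.tensor([1, 0], dtype=torch.int32))
assert reordered.tensor_attr.tolist() == [2, 1]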
Code example #2
    def _intersect_calc_scores_mmi_exact(
        self, dense_fsa_vec: k2.DenseFsaVec, num_graphs: 'k2.Fsa', den_graph: 'k2.Fsa', return_lats: bool = True,
    ):
        device = dense_fsa_vec.device
        assert device == num_graphs.device and device == den_graph.device

        num_fsas = num_graphs.shape[0]
        assert dense_fsa_vec.dim0() == num_fsas

        den_graph = den_graph.clone()
        num_graphs = num_graphs.clone()

        num_den_graphs = k2.cat([num_graphs, den_graph])

        # NOTE: The a_to_b_map in k2.intersect_dense must be sorted
        # so the following reorders num_den_graphs.

        # [0, 1, 2, ... ]
        num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32)

        # [num_fsas, num_fsas, num_fsas, ... ]
        den_graph_indexes = torch.tensor([num_fsas] * num_fsas, dtype=torch.int32)

        # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
        num_den_graphs_indexes = torch.stack([num_graphs_indexes, den_graph_indexes]).t().reshape(-1).to(device)

        num_den_reordered_graphs = k2.index_fsa(num_den_graphs, num_den_graphs_indexes)

        # [[0, 1, 2, ...]]
        a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)

        # [[0, 1, 2, ...]] -> [0, 0, 1, 1, 2, 2, ... ]
        a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1).to(device)

        num_den_lats = k2.intersect_dense(
            a_fsas=num_den_reordered_graphs,
            b_fsas=dense_fsa_vec,
            output_beam=self.intersect_conf.output_beam,
            a_to_b_map=a_to_b_map,
            seqframe_idx_name="seqframe_idx" if return_lats else None,
        )

        num_den_tot_scores = num_den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
        num_tot_scores = num_den_tot_scores[::2]
        den_tot_scores = num_den_tot_scores[1::2]

        if return_lats:
            lat_slice = torch.arange(num_fsas, dtype=torch.int32).to(device) * 2
            return (
                num_tot_scores,
                den_tot_scores,
                k2.index_fsa(num_den_lats, lat_slice),
                k2.index_fsa(num_den_lats, lat_slice + 1),
            )
        else:
            return num_tot_scores, den_tot_scores, None, None
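
The subtle part above is the index bookkeeping: numerator and denominator graphs are interleaved so that utterance i is scored against graphs 2*i and 2*i+1 of the reordered FsaVec. A standalone torch-only sketch of that arithmetic, assuming a hypothetical num_fsas = 3:

import torch

num_fsas = 3  # hypothetical batch size
num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32)               # [0, 1, 2]
den_graph_indexes = torch.tensor([num_fsas] * num_fsas, dtype=torch.int32)   # [3, 3, 3]

# Interleave the two index lists: [0, 3, 1, 3, 2, 3]
num_den_graphs_indexes = torch.stack([num_graphs_indexes, den_graph_indexes]).t().reshape(-1)
assert num_den_graphs_indexes.tolist() == [0, 3, 1, 3, 2, 3]

# Each dense FSA is intersected with two graphs: [0, 0, 1, 1, 2, 2]
a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)
a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1)
assert a_to_b_map.tolist() == [0, 0, 1, 1, 2, 2]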
Code example #3
File: mmi_graph.py  Project: desh2608/snowfall
    def compile(self,
                texts: Iterable[str],
                P: k2.Fsa,
                replicate_den: bool = True) -> Tuple[k2.Fsa, k2.Fsa]:
        '''Create numerator and denominator graphs from transcripts
        and the bigram phone LM.

        Args:
          texts:
            A list of transcripts. Within a transcript, words are
            separated by spaces.
          P:
            The bigram phone LM created by :func:`create_bigram_phone_lm`.
          replicate_den:
            If True, the returned den_graph is replicated to match the number
            of FSAs in the returned num_graph; if False, the returned den_graph
            contains only a single FSA.
        Returns:
          A tuple (num_graph, den_graph), where

            - `num_graph` is the numerator graph. It is an FsaVec with
              shape `(len(texts), None, None)`.

            - `den_graph` is the denominator graph. It is an FsaVec with the same
              shape as `num_graph` if replicate_den is True; otherwise, it
              is an FsaVec containing only a single FSA.
        '''
        assert P.device == self.device
        P_with_self_loops = k2.add_epsilon_self_loops(P)

        ctc_topo_P = k2.intersect(self.ctc_topo_inv,
                                  P_with_self_loops,
                                  treat_epsilons_specially=False).invert()

        ctc_topo_P = k2.arc_sort(ctc_topo_P)

        num_graphs = self.build_num_graphs(texts)
        num_graphs_with_self_loops = k2.remove_epsilon_and_add_self_loops(
            num_graphs)

        num_graphs_with_self_loops = k2.arc_sort(num_graphs_with_self_loops)

        num = k2.compose(ctc_topo_P,
                         num_graphs_with_self_loops,
                         treat_epsilons_specially=False)
        num = k2.arc_sort(num)

        ctc_topo_P_vec = k2.create_fsa_vec([ctc_topo_P.detach()])
        if replicate_den:
            indexes = torch.zeros(len(texts),
                                  dtype=torch.int32,
                                  device=self.device)
            den = k2.index_fsa(ctc_topo_P_vec, indexes)
        else:
            den = ctc_topo_P_vec

        return num, den
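
The replicate_den branch relies on k2.index_fsa repeating an entry as many times as the index tensor asks for, which is how a single denominator graph is broadcast over the batch. A minimal CPU-only sketch of that trick with a hypothetical batch size:

import torch
import k2

den = k2.Fsa.from_str('''
    0 1 -1 0.0
    1
''')
den_vec = k2.create_fsa_vec([den])   # FsaVec containing a single FSA

num_utterances = 4                   # hypothetical batch size
indexes = torch.zeros(num_utterances, dtype=torch.int32)
replicated = k2.index_fsa(den_vec, indexes)
assert replicated.shape == (4, None, None)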
Code example #4
    def compile(self, texts: Iterable[str],
                P: k2.Fsa) -> Tuple[k2.Fsa, k2.Fsa, k2.Fsa]:
        '''Create numerator and denominator graphs from transcripts
        and the bigram phone LM.

        Args:
          texts:
            A list of transcripts. Within a transcript, words are
            separated by spaces.
          P:
            The bigram phone LM created by :func:`create_bigram_phone_lm`.
        Returns:
          A tuple (num_graph, den_graph, decoding_graph), where

            - `num_graph` is the numerator graph. It is an FsaVec with
              shape `(len(texts), None, None)`.
              It is the result of compose(ctc_topo, P, L, transcript)

            - `den_graph` is the denominator graph. It is an FsaVec with the same
              shape as the `num_graph`.
              It is the result of compose(ctc_topo, P).

            - decoding_graph: It is the result of compose(ctc_topo, L_disambig, G)
              Note that it is a single Fsa, not an FsaVec.
        '''
        assert P.device == self.device
        P_with_self_loops = k2.add_epsilon_self_loops(P)

        ctc_topo_P = k2.intersect(self.ctc_topo_inv,
                                  P_with_self_loops,
                                  treat_epsilons_specially=False).invert()
        ctc_topo_P = k2.arc_sort(ctc_topo_P)

        num_graphs = self.build_num_graphs(texts)

        num_graphs_with_self_loops = k2.remove_epsilon_and_add_self_loops(
            num_graphs)

        num_graphs_with_self_loops = k2.arc_sort(num_graphs_with_self_loops)

        num = k2.compose(ctc_topo_P,
                         num_graphs_with_self_loops,
                         treat_epsilons_specially=False,
                         inner_labels='phones')
        num = k2.arc_sort(num)

        ctc_topo_P_vec = k2.create_fsa_vec([ctc_topo_P.detach()])
        indexes = torch.zeros(len(texts),
                              dtype=torch.int32,
                              device=self.device)
        den = k2.index_fsa(ctc_topo_P_vec, indexes)

        return num, den, self.decoding_graph
Code example #5
def _intersect_device(
    a_fsas: k2.Fsa,
    b_fsas: k2.Fsa,
    b_to_a_map: torch.Tensor,
    sorted_match_a: bool,
    batch_size: int = 500,
):
    """Wrap k2.intersect_device

    This is a wrapper around k2.intersect_device. Its purpose is to split
    b_fsas into several batches and process each batch separately to avoid
    CUDA out-of-memory errors.
    The arguments and return value of this function are the same as those of
    k2.intersect_device, except for the extra batch_size argument.

    NOTE: You can decrease batch_size in case of a CUDA out-of-memory error.
    """
    num_fsas = b_fsas.shape[0]
    if num_fsas <= batch_size:
        return k2.intersect_device(
            a_fsas, b_fsas, b_to_a_map=b_to_a_map, sorted_match_a=sorted_match_a
        )

    num_batches = int(math.ceil(float(num_fsas) / batch_size))
    splits = []
    for i in range(num_batches):
        start = i * batch_size
        end = min(start + batch_size, num_fsas)
        splits.append((start, end))

    ans = []
    for start, end in splits:
        indexes = torch.arange(start, end).to(b_to_a_map)

        fsas = k2.index_fsa(b_fsas, indexes)
        b_to_a = k2.index_select(b_to_a_map, indexes)
        path_lats = k2.intersect_device(
            a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a
        )
        ans.append(path_lats)

    return k2.cat(ans)
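
Only the intersection itself is expensive; the batching is plain index arithmetic over b_fsas. A standalone illustration of the split computation, assuming hypothetical values num_fsas = 7 and batch_size = 3:

import math

num_fsas, batch_size = 7, 3          # hypothetical values
num_batches = int(math.ceil(float(num_fsas) / batch_size))
splits = [(i * batch_size, min((i + 1) * batch_size, num_fsas))
          for i in range(num_batches)]
# Each (start, end) pair becomes an index range passed to k2.index_fsa above.
assert splits == [(0, 3), (3, 6), (6, 7)]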
Code example #6
File: nbest.py  Project: k2-fsa/k2
    def top_k(self, k: int) -> 'Nbest':
        '''Get a subset of paths in the Nbest. The resulting Nbest is regular
        in that each sequence (i.e., utterance) has the same number of
        paths (k).

        We select the top-k paths according to the total_scores of each path.
        If an utterance has fewer than k paths, then its last path, after sorting
        by tot_scores in descending order, is repeated so that each utterance
        has exactly k paths.

        Args:
          k:
            Number of paths in each utterance.
        Returns:
          Return a new Nbest with a regular shape.
        '''
        ragged_scores = self.total_scores()

        # indexes contains idx01's for self.shape
        # ragged_scores.values()[indexes] is sorted
        indexes = k2.ragged.sort_sublist(ragged_scores,
                                         descending=True,
                                         need_new2old_indexes=True)

        ragged_indexes = k2.RaggedInt(self.shape, indexes)

        padded_indexes = k2.ragged.pad(ragged_indexes,
                                       mode='replicate',
                                       value=-1)
        assert torch.ge(padded_indexes, 0).all(), \
                'Some utterances contain empty ' \
                f'n-best: {self.shape.row_splits(1)}'

        # Select the idx01's of top-k paths of each utterance
        top_k_indexes = padded_indexes[:, :k].flatten().contiguous()

        top_k_fsas = k2.index_fsa(self.fsa, top_k_indexes)

        top_k_shape = k2.ragged.regular_ragged_shape(dim0=self.shape.dim0(),
                                                     dim1=k)
        return Nbest(top_k_fsas, top_k_shape)
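
The key step is turning the ragged, per-utterance lists of path indexes into a dense (num_utterances, k) table before flattening it for k2.index_fsa; padding with mode='replicate' repeats each utterance's last surviving path. A torch-only sketch of the padding and slicing with hypothetical index values:

import torch

# Hypothetical padded indexes: utterance 0 had three paths, utterance 1 had
# only one, so its single path index 9 was replicated to fill the row.
padded_indexes = torch.tensor([[7, 5, 2],
                               [9, 9, 9]])
k = 2
top_k_indexes = padded_indexes[:, :k].flatten().contiguous()
assert top_k_indexes.tolist() == [7, 5, 9, 9]   # fed to k2.index_fsa above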
Code example #7
    def test(self):
        s0 = '''
            0 1 1 0.1
            0 2 2 0.2
            1 2 3 0.3
            2 3 -1 0.4
            3
        '''
        s1 = '''
            0 1 -1 0.5
            1
        '''
        s2 = '''
            0 2 1 0.6
            0 1 2 0.7
            1 3 -1 0.8
            2 1 3 0.9
            3
        '''
        for device in self.devices:
            fsa0 = k2.Fsa.from_str(s0).to(device).requires_grad_(True)
            fsa1 = k2.Fsa.from_str(s1).to(device).requires_grad_(True)
            fsa2 = k2.Fsa.from_str(s2).to(device).requires_grad_(True)

            fsa_vec = k2.create_fsa_vec([fsa0, fsa1, fsa2])

            new_fsa21 = k2.index_fsa(
                fsa_vec, torch.tensor([2, 1], dtype=torch.int32,
                                      device=device))
            assert new_fsa21.shape == (2, None, None)
            assert torch.all(
                torch.eq(
                    new_fsa21.arcs.values()[:, :3],
                    torch.tensor([
                        # fsa 2
                        [0, 2, 1],
                        [0, 1, 2],
                        [1, 3, -1],
                        [2, 1, 3],
                        # fsa 1
                        [0, 1, -1]
                    ]).to(torch.int32).to(device)))

            scale = torch.arange(new_fsa21.scores.numel(), device=device)
            (new_fsa21.scores * scale).sum().backward()
            assert torch.allclose(fsa0.scores.grad,
                                  torch.tensor([0., 0, 0, 0], device=device))
            assert torch.allclose(fsa1.scores.grad,
                                  torch.tensor([4.], device=device))
            assert torch.allclose(
                fsa2.scores.grad, torch.tensor([0., 1., 2., 3.],
                                               device=device))

            # now select only a single FSA
            fsa0.scores.grad = None
            fsa1.scores.grad = None
            fsa2.scores.grad = None

            new_fsa0 = k2.index_fsa(
                fsa_vec, torch.tensor([0], dtype=torch.int32, device=device))
            assert new_fsa0.shape == (1, None, None)

            scale = torch.arange(new_fsa0.scores.numel(), device=device)
            (new_fsa0.scores * scale).sum().backward()
            assert torch.allclose(
                fsa0.scores.grad, torch.tensor([0., 1., 2., 3.],
                                               device=device))
            assert torch.allclose(fsa1.scores.grad,
                                  torch.tensor([0.], device=device))
            assert torch.allclose(
                fsa2.scores.grad, torch.tensor([0., 0., 0., 0.],
                                               device=device))
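
Besides selecting FSAs, the test also checks that k2.index_fsa is differentiable: gradients on the selected copies flow back into the source FSAs' scores. A minimal CPU-only sketch of that behaviour, using only calls that already appear in the test:

import torch
import k2

src = k2.Fsa.from_str('''
    0 1 -1 0.5
    1
''').requires_grad_(True)
vec = k2.create_fsa_vec([src])

picked = k2.index_fsa(vec, torch.tensor([0], dtype=torch.int32))
picked.scores.sum().backward()
assert torch.allclose(src.scores.grad, torch.tensor([1.]))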
Code example #8
File: graph_decoders.py  Project: quuhua911/NeMo
    def decode(
        self,
        log_probs: torch.Tensor,
        log_probs_length: torch.Tensor,
        return_lattices: bool = False,
        return_ilabels: bool = False,
        output_aligned: bool = True,
    ) -> Union['k2.Fsa', Tuple[List[torch.Tensor], List[torch.Tensor]]]:
        if self.decoding_graph is None:
            self.decoding_graph = self.base_graph

        if self.blank != 0:
            # rearrange log_probs to put blank at the first place
            # and shift targets to emulate blank = 0
            log_probs, _ = make_blank_first(self.blank, log_probs, None)
        supervisions, order = create_supervision(log_probs_length)
        if self.decoding_graph.shape[0] > 1:
            self.decoding_graph = k2.index_fsa(self.decoding_graph, order).to(device=log_probs.device)

        if log_probs.device != self.device:
            self.to(log_probs.device)
        dense_fsa_vec = (
            prep_padded_densefsavec(log_probs, supervisions)
            if self.pad_fsavec
            else k2.DenseFsaVec(log_probs, supervisions)
        )

        if self.intersect_pruned:
            lats = k2.intersect_dense_pruned(
                a_fsas=self.decoding_graph,
                b_fsas=dense_fsa_vec,
                search_beam=self.intersect_conf.search_beam,
                output_beam=self.intersect_conf.output_beam,
                min_active_states=self.intersect_conf.min_active_states,
                max_active_states=self.intersect_conf.max_active_states,
            )
        else:
            indices = torch.zeros(dense_fsa_vec.dim0(), dtype=torch.int32, device=self.device)
            dec_graphs = (
                k2.index_fsa(self.decoding_graph, indices)
                if self.decoding_graph.shape[0] == 1
                else self.decoding_graph
            )
            lats = k2.intersect_dense(dec_graphs, dense_fsa_vec, self.intersect_conf.output_beam)
        if self.pad_fsavec:
            shift_labels_inpl([lats], -1)
        self.decoding_graph = None

        if return_lattices:
            lats = k2.index_fsa(lats, invert_permutation(order).to(device=log_probs.device))
            if self.blank != 0:
                # change only ilabels
                # suppose self.blank == self.num_classes - 1
                lats.labels = torch.where(lats.labels == 0, self.blank, lats.labels - 1)
            return lats
        else:
            shortest_path_fsas = k2.index_fsa(
                k2.shortest_path(lats, True), invert_permutation(order).to(device=log_probs.device),
            )
            shortest_paths = []
            probs = []
            # direct iterating does not work as expected
            for i in range(shortest_path_fsas.shape[0]):
                shortest_path_fsa = shortest_path_fsas[i]
                labels = (
                    shortest_path_fsa.labels[:-1].to(dtype=torch.long)
                    if return_ilabels
                    else shortest_path_fsa.aux_labels[:-1].to(dtype=torch.long)
                )
                if self.blank != 0:
                    # suppose self.blank == self.num_classes - 1
                    labels = torch.where(labels == 0, self.blank, labels - 1)
                if not return_ilabels and not output_aligned:
                    labels = labels[labels != self.blank]
                shortest_paths.append(labels[::2] if self.pad_fsavec else labels)
                probs.append(get_arc_weights(shortest_path_fsa)[:-1].to(device=log_probs.device).exp())
            return shortest_paths, probs
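
Two of the k2.index_fsa calls above undo the length-sorting done by create_supervision: order arranges the utterances for intersection, and the resulting lattices (or shortest paths) are re-indexed with its inverse to restore the original batch order. A torch-only sketch of the inverse-permutation idea with a hypothetical order (invert_permutation in the snippet above is assumed to compute exactly this):

import torch

order = torch.tensor([2, 0, 1], dtype=torch.int32)   # hypothetical sorted order
inverse = torch.empty_like(order)
inverse[order.to(torch.long)] = torch.arange(order.numel(), dtype=torch.int32)
assert inverse.tolist() == [1, 2, 0]
# Indexing the sorted lattices with `inverse` puts utterance i back at position i.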