Beispiel #1
0
    def _intersect_calc_scores_mmi_exact(
        self, dense_fsa_vec: k2.DenseFsaVec, num_graphs: 'k2.Fsa', den_graph: 'k2.Fsa', return_lats: bool = True,
    ):
        """Score numerator and denominator graphs with one ``k2.intersect_dense`` call.

        The numerator graphs and the denominator graph are concatenated and
        interleaved as ``[num_0, den, num_1, den, ...]`` so that the
        ``a_to_b_map`` handed to ``k2.intersect_dense`` is sorted.

        Args:
          dense_fsa_vec:
            Dense FSAs to intersect with; ``dense_fsa_vec.dim0()`` must equal
            the number of numerator graphs.
          num_graphs:
            Numerator graphs, one per entry of ``dense_fsa_vec``.
          den_graph:
            Denominator graph. Only the FSA that lands at position
            ``num_graphs.shape[0]`` of the concatenation (i.e. the first FSA
            of ``den_graph``) is used, and it is paired with every entry of
            ``dense_fsa_vec``.
          return_lats:
            If True, ``seqframe_idx_name="seqframe_idx"`` is requested from
            the intersection and the numerator/denominator lattices are
            returned as well.

        Returns:
          A 4-tuple ``(num_tot_scores, den_tot_scores, num_lats, den_lats)``.
          The two lattice entries are ``None`` when ``return_lats`` is False.
        """
        device = dense_fsa_vec.device
        assert device == num_graphs.device and device == den_graph.device

        num_fsas = num_graphs.shape[0]
        assert dense_fsa_vec.dim0() == num_fsas

        den_graph = den_graph.clone()
        num_graphs = num_graphs.clone()

        num_den_graphs = k2.cat([num_graphs, den_graph])

        # NOTE: The a_to_b_map in k2.intersect_dense must be sorted
        # so the following reorders num_den_graphs.
        # All index tensors are allocated directly on `device` to avoid
        # building them on the CPU and copying them over afterwards.

        # [0, 1, 2, ... ]
        num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32, device=device)

        # [num_fsas, num_fsas, num_fsas, ... ]
        den_graph_indexes = torch.full((num_fsas,), num_fsas, dtype=torch.int32, device=device)

        # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
        num_den_graphs_indexes = torch.stack([num_graphs_indexes, den_graph_indexes], dim=1).reshape(-1)

        num_den_reordered_graphs = k2.index_fsa(num_den_graphs, num_den_graphs_indexes)

        # [0, 1, 2, ...] -> [0, 0, 1, 1, 2, 2, ... ]
        a_to_b_map = num_graphs_indexes.repeat_interleave(2)

        num_den_lats = k2.intersect_dense(
            a_fsas=num_den_reordered_graphs,
            b_fsas=dense_fsa_vec,
            output_beam=self.intersect_conf.output_beam,
            a_to_b_map=a_to_b_map,
            seqframe_idx_name="seqframe_idx" if return_lats else None,
        )

        num_den_tot_scores = num_den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
        # Even positions hold numerator scores, odd positions denominator scores.
        num_tot_scores = num_den_tot_scores[::2]
        den_tot_scores = num_den_tot_scores[1::2]

        if not return_lats:
            return num_tot_scores, den_tot_scores, None, None

        # [0, 2, 4, ...] selects the numerator lattices; +1 the denominator ones.
        lat_slice = torch.arange(0, 2 * num_fsas, 2, dtype=torch.int32, device=device)
        return (
            num_tot_scores,
            den_tot_scores,
            k2.index_fsa(num_den_lats, lat_slice),
            k2.index_fsa(num_den_lats, lat_slice + 1),
        )
Beispiel #2
0
    def test_cat_fsa_vec(self):
        '''Check how k2.cat merges the attributes of two single-FSA vectors.

        Expectations encoded below:
          - arcs of each input are preserved in order;
          - tensor / ragged attributes present on both inputs are
            concatenated, those present on only one input are dropped;
          - non-tensor attributes are kept, and on a name clash the value
            from the first input is the one that survives.
        '''
        for device in self.devices:
            s = '''
                0 1 1 0.1
                0 1 2 0.2
                1 2 -1 0.3
                2
            '''
            # First input FSA and its attributes.
            first = k2.Fsa.from_str(s).to(device)
            first.tensor_attr1 = torch.tensor([1, 2, 3]).to(device)
            first.tensor_attr2 = torch.tensor([4, 5, 6]).to(device)
            first.non_tensor_attr1 = 'fsa1'

            ragged = k2.RaggedTensor('[[1 2] [] [3 4 5]]')
            first.ragged_tensor_attr1 = ragged.to(device)
            ragged = k2.RaggedTensor('[[1 20] [30] [5]]')
            first.ragged_tensor_attr2 = ragged.to(device)

            # Second input FSA: shares some attribute names with the first.
            second = k2.Fsa.from_str(s).to(device)
            second.tensor_attr1 = torch.tensor([10, 20, 30]).to(device)
            second.tensor_attr3 = torch.tensor([40, 50, 60]).to(device)
            second.non_tensor_attr1 = 'fsa'
            second.non_tensor_attr2 = 'fsa2'

            ragged = k2.RaggedTensor('[[3] [4 5] [6 7]]')
            second.ragged_tensor_attr1 = ragged.to(device)
            ragged = k2.RaggedTensor('[[1 0] [0] [-1]]')
            second.ragged_tensor_attr3 = ragged.to(device)

            fsa_vec = k2.cat([
                k2.create_fsa_vec([first]),
                k2.create_fsa_vec([second]),
            ])

            # Arcs are carried over unchanged, in input order.
            assert str(fsa_vec[0].arcs) == str(first.arcs)
            assert str(fsa_vec[1].arcs) == str(second.arcs)

            # Tensor attributes that only one input has are not propagated.
            for dropped in ('tensor_attr2', 'tensor_attr3'):
                assert not hasattr(fsa_vec, dropped)

            assert fsa_vec.non_tensor_attr1 == first.non_tensor_attr1
            assert fsa_vec.non_tensor_attr2 == second.non_tensor_attr2

            expected_tensor = torch.tensor([1, 2, 3, 10, 20, 30]).to(device)
            assert torch.all(torch.eq(fsa_vec.tensor_attr1, expected_tensor))

            expected_ragged = k2.RaggedTensor([
                [1, 2],
                [],
                [3, 4, 5],
                [3],
                [4, 5],
                [6, 7],
            ]).to(device)
            assert fsa_vec.ragged_tensor_attr1 == expected_ragged

            # Ragged attributes that only one input has are not propagated.
            for dropped in ('ragged_tensor_attr2', 'ragged_tensor_attr3'):
                assert not hasattr(fsa_vec, dropped)
Beispiel #3
0
def _intersect_device(
    a_fsas: k2.Fsa,
    b_fsas: k2.Fsa,
    b_to_a_map: torch.Tensor,
    sorted_match_a: bool,
    batch_size: int = 500,
):
    """Wrap k2.intersect_device

    This is a wrapper of k2.intersect_device and its purpose is to split
    b_fsas into several batches and process each batch separately to avoid
    CUDA OOM error.
    The arguments and return value of this function are the same as
    k2.intersect_device.

    When ``b_fsas`` holds at most ``batch_size`` FSAs, a single call to
    k2.intersect_device is made. Otherwise ``b_fsas`` and ``b_to_a_map`` are
    sliced into consecutive chunks of ``batch_size`` entries (the last chunk
    may be shorter) and the per-chunk results are joined with k2.cat.

    NOTE: You can decrease batch_size in case of CUDA out of memory error.

    Raises:
      ValueError: if batching is required and ``batch_size`` is not positive.
    """
    num_fsas = b_fsas.shape[0]
    if num_fsas <= batch_size:
        return k2.intersect_device(
            a_fsas, b_fsas, b_to_a_map=b_to_a_map, sorted_match_a=sorted_match_a
        )

    # Only reachable when num_fsas > batch_size; a non-positive batch size
    # cannot partition the input, so fail with a clear message.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    ans = []
    for start in range(0, num_fsas, batch_size):
        end = min(start + batch_size, num_fsas)
        # Use the dtype and device of b_to_a_map for the chunk indexes.
        indexes = torch.arange(
            start, end, dtype=b_to_a_map.dtype, device=b_to_a_map.device
        )

        fsas = k2.index_fsa(b_fsas, indexes)
        b_to_a = k2.index_select(b_to_a_map, indexes)
        path_lats = k2.intersect_device(
            a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a
        )
        ans.append(path_lats)

    return k2.cat(ans)
Beispiel #4
0
def _compute_mmi_loss_exact_optimized(
        nnet_output: torch.Tensor,
        texts: List[str],
        supervision_segments: torch.Tensor,
        graph_compiler: MmiTrainingGraphCompiler,
        P: k2.Fsa,
        den_scale: float = 1.0,
        output_beam: float = 10.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    The function name contains `exact`, which means it uses
    k2.intersect_dense rather than a pruned intersection; the only pruning
    related value passed to it is `output_beam`.

    `optimized` in the function name means this function is optimized
    in that it calls k2.intersect_dense only once

    Note:
      It is faster at the cost of using more memory.

    Args:
      nnet_output:
        A 3-D tensor of shape [N, T, C]
      texts:
        The transcript. Each element consists of space(s) separated words.
      supervision_segments:
        A 2-D tensor that will be passed to :func:`k2.DenseFsaVec`.
        Its column 2 is also passed to `get_tot_objf_and_num_frames`.
      graph_compiler:
        Used to build num_graphs and den_graphs
      P:
        Represents a bigram Fsa.
      den_scale:
        The scale applied to the denominator tot_scores.
      output_beam:
        Forwarded to k2.intersect_dense as `output_beam`. The default
        keeps the previously hard-coded value of 10.0.

    Returns:
      The 3-tuple returned by `get_tot_objf_and_num_frames` for
      `num_tot_scores - den_scale * den_tot_scores`, named here
      (tot_score, tot_frames, all_frames).
    '''
    num_graphs, den_graphs = graph_compiler.compile(texts,
                                                    P,
                                                    replicate_den=False)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    device = num_graphs.device

    num_fsas = num_graphs.shape[0]
    assert dense_fsa_vec.dim0() == num_fsas

    assert den_graphs.shape[0] == 1

    # the aux_labels of num_graphs is k2.RaggedInt
    # but it is torch.Tensor for den_graphs.
    #
    # The following converts den_graphs.aux_labels
    # from torch.Tensor to k2.RaggedInt so that
    # we can use k2.cat() below
    den_graphs.convert_attr_to_ragged_(name='aux_labels')

    # The motivation to concatenate num_graphs and den_graphs
    # is to reduce the number of calls to k2.intersect_dense.
    num_den_graphs = k2.cat([num_graphs, den_graphs])

    # NOTE: The a_to_b_map in k2.intersect_dense must be sorted
    # so the following reorders num_den_graphs.
    #
    # The index tensors are allocated directly on `device` instead of
    # being built on the CPU and copied afterwards.

    # [0, 1, 2, ... ]
    num_graphs_indexes = torch.arange(num_fsas,
                                      dtype=torch.int32,
                                      device=device)

    # [num_fsas, num_fsas, num_fsas, ... ]
    den_graphs_indexes = torch.full((num_fsas,),
                                    num_fsas,
                                    dtype=torch.int32,
                                    device=device)

    # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
    num_den_graphs_indexes = torch.stack(
        [num_graphs_indexes, den_graphs_indexes], dim=1).reshape(-1)

    num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes)

    # The following code computes a_to_b_map
    # [0, 1, 2, ...] -> [0, 0, 1, 1, 2, 2, ... ]
    a_to_b_map = num_graphs_indexes.repeat_interleave(2)

    num_den_lats = k2.intersect_dense(num_den_reordered_graphs,
                                      dense_fsa_vec,
                                      output_beam=output_beam,
                                      a_to_b_map=a_to_b_map)

    num_den_tot_scores = num_den_lats.get_tot_scores(log_semiring=True,
                                                     use_double_scores=True)

    # Even positions are numerator scores, odd positions denominator scores.
    num_tot_scores = num_den_tot_scores[::2]
    den_tot_scores = num_den_tot_scores[1::2]

    tot_scores = num_tot_scores - den_scale * den_tot_scores
    tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
        tot_scores, supervision_segments[:, 2])
    return tot_score, tot_frames, all_frames
Beispiel #5
0
    def forward(
            self, nnet_output: torch.Tensor, texts: List,
            supervision_segments: torch.Tensor
    ) -> Tuple[torch.Tensor, int, int]:
        '''
        Compute the MMI objective with a single call to k2.intersect_dense.

        Numerator graphs and the single denominator graph are built by
        `self.graph_compiler.compile(texts, self.P, replicate_den=False)`,
        concatenated, and interleaved as [num_0, den, num_1, den, ...] so
        that one intersection scores both.

        Args:
          nnet_output:
            Passed, together with `supervision_segments`, to
            :func:`k2.DenseFsaVec`.
          texts:
            Passed to `self.graph_compiler.compile`.
          supervision_segments:
            A 2-D tensor passed to :func:`k2.DenseFsaVec`; its column 2 is
            also passed to `get_tot_objf_and_num_frames`.

        Returns:
          The 3-tuple returned by `get_tot_objf_and_num_frames` for
          `num_tot_scores - self.den_scale * den_tot_scores`, named here
          (tot_score, tot_frames, all_frames).
        '''
        num_graphs, den_graphs = self.graph_compiler.compile(
            texts, self.P, replicate_den=False)

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

        device = num_graphs.device

        num_fsas = num_graphs.shape[0]
        assert dense_fsa_vec.dim0() == num_fsas

        assert den_graphs.shape[0] == 1

        # the aux_labels of num_graphs is k2.RaggedInt
        # but it is torch.Tensor for den_graphs.
        #
        # The following converts den_graphs.aux_labels
        # from torch.Tensor to k2.RaggedInt so that
        # we can use k2.cat() below
        den_graphs.convert_attr_to_ragged_(name='aux_labels')

        num_den_graphs = k2.cat([num_graphs, den_graphs])

        # NOTE: The a_to_b_map in k2.intersect_dense must be sorted
        # so the following reorders num_den_graphs.
        #
        # The index tensors are allocated directly on `device` instead of
        # being built on the CPU and copied afterwards.

        # [0, 1, 2, ... ]
        num_graphs_indexes = torch.arange(num_fsas,
                                          dtype=torch.int32,
                                          device=device)

        # [num_fsas, num_fsas, num_fsas, ... ]
        den_graphs_indexes = torch.full((num_fsas,),
                                        num_fsas,
                                        dtype=torch.int32,
                                        device=device)

        # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
        num_den_graphs_indexes = torch.stack(
            [num_graphs_indexes, den_graphs_indexes], dim=1).reshape(-1)

        num_den_reordered_graphs = k2.index(num_den_graphs,
                                            num_den_graphs_indexes)

        # [0, 1, 2, ...] -> [0, 0, 1, 1, 2, 2, ... ]
        a_to_b_map = num_graphs_indexes.repeat_interleave(2)

        num_den_lats = k2.intersect_dense(num_den_reordered_graphs,
                                          dense_fsa_vec,
                                          output_beam=10.0,
                                          a_to_b_map=a_to_b_map)

        num_den_tot_scores = num_den_lats.get_tot_scores(
            log_semiring=True, use_double_scores=True)

        # Even positions are numerator scores, odd positions denominator
        # scores.
        num_tot_scores = num_den_tot_scores[::2]
        den_tot_scores = num_den_tot_scores[1::2]

        tot_scores = num_tot_scores - self.den_scale * den_tot_scores
        tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
            tot_scores, supervision_segments[:, 2])
        return tot_score, tot_frames, all_frames