Example #1
    def _intersect_calc_scores_mmi_pruned(
        self, dense_fsa_vec: k2.DenseFsaVec, num_graphs: 'k2.Fsa', den_graph: 'k2.Fsa', return_lats: bool = True,
    ):
        device = dense_fsa_vec.device
        assert device == num_graphs.device and device == den_graph.device

        num_fsas = num_graphs.shape[0]
        assert dense_fsa_vec.dim0() == num_fsas

        num_lats = k2.intersect_dense(
            a_fsas=num_graphs,
            b_fsas=dense_fsa_vec,
            output_beam=self.intersect_conf.output_beam,
            seqframe_idx_name="seqframe_idx" if return_lats else None,
        )
        den_lats = k2.intersect_dense_pruned(
            a_fsas=den_graph,
            b_fsas=dense_fsa_vec,
            search_beam=self.intersect_conf.search_beam,
            output_beam=self.intersect_conf.output_beam,
            min_active_states=self.intersect_conf.min_active_states,
            max_active_states=self.intersect_conf.max_active_states,
            seqframe_idx_name="seqframe_idx" if return_lats else None,
        )

        # use_double_scores=True matters here:
        # with single precision, the total scores sometimes suffer from rounding errors
        num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
        den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)

        if return_lats:
            return num_tot_scores, den_tot_scores, num_lats, den_lats
        else:
            return num_tot_scores, den_tot_scores, None, None
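
A minimal, self-contained sketch of the same pruned-intersection call on toy inputs (the CTC topology, beam values, and shapes below are illustrative assumptions, not the configuration this snippet actually runs with):

import torch
import k2

graph = k2.arc_sort(k2.ctc_topo(max_token=2))  # blank = 0 plus tokens {1, 2}
graph_vec = k2.create_fsa_vec([graph])

log_prob = torch.randn(1, 10, 3).log_softmax(dim=-1)  # [N, T, C]
supervision_segments = torch.tensor([[0, 0, 10]], dtype=torch.int32)
dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)

lats = k2.intersect_dense_pruned(
    a_fsas=graph_vec,
    b_fsas=dense_fsa_vec,
    search_beam=20.0,
    output_beam=8.0,
    min_active_states=30,
    max_active_states=10000,
)
tot_scores = lats.get_tot_scores(log_semiring=True, use_double_scores=True)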
Example #2
def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
           device: Union[str, torch.device], HLG: Fsa, symbols: SymbolTable):
    tot_num_cuts = len(dataloader.dataset.cuts)
    num_cuts = 0
    results = []  # a list of pairs (ref_words, hyp_words)
    for batch_idx, batch in enumerate(dataloader):
        feature = batch['inputs']
        supervisions = batch['supervisions']
        supervision_segments = torch.stack(
            (supervisions['sequence_idx'],
             torch.floor_divide(supervisions['start_frame'],
                                model.subsampling_factor),
             torch.floor_divide(supervisions['num_frames'],
                                model.subsampling_factor)), 1).to(torch.int32)
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]
        texts = supervisions['text']
        assert feature.ndim == 3

        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
        with torch.no_grad():
            nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2,
                                          1)  # now nnet_output is [N, T, C]

        #  blank_bias = -3.0
        #  nnet_output[:, :, 0] += blank_bias

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
        # assert HLG.is_cuda()
        assert HLG.device == nnet_output.device, \
            f"Check failed: HLG.device ({HLG.device}) == nnet_output.device ({nnet_output.device})"
        # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
        # thus `tot_scores` will be `inf`. Definitely we need to handle this later.
        lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, 7.0, 30,
                                             10000)

        # lattices = k2.intersect_dense(HLG, dense_fsa_vec, 10.0)
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        assert best_paths.shape[0] == len(texts)
        hyps = get_texts(best_paths, indices)
        assert len(hyps) == len(texts)

        for i in range(len(texts)):
            hyp_words = [symbols.get(x) for x in hyps[i]]
            ref_words = texts[i].split(' ')
            results.append((ref_words, hyp_words))

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, cuts processed so far: {}/{} ({:.6f}%)'.format(
                    batch_idx, num_cuts, tot_num_cuts,
                    float(num_cuts) / tot_num_cuts * 100))

        num_cuts += len(texts)

    return results
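
A note on `indices` above: k2 expects supervision segments ordered by decreasing duration, so the hypotheses come back in that sorted order and `get_texts(best_paths, indices)` maps them back to the original utterances. A small stand-alone sketch of the bookkeeping (tensor values are illustrative):

import torch

supervision_segments = torch.tensor(
    [[0, 0, 17], [1, 0, 42], [2, 0, 29]], dtype=torch.int32)
indices = torch.argsort(supervision_segments[:, 2], descending=True)
sorted_segments = supervision_segments[indices]  # row order: 1, 2, 0

# Lattice row i corresponds to utterance indices[i]; inverting the
# permutation restores the original batch order.
inverse = torch.empty_like(indices)
inverse[indices] = torch.arange(len(indices))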
Example #3
    def test_two_dense(self):
        s = '''
            0 1 1 1.0
            1 1 1 50.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''
        for device in self.devices:
            fsa = k2.Fsa.from_str(s).to(device)
            fsa.requires_grad_(True)
            fsa_vec = k2.create_fsa_vec([fsa])
            log_prob = torch.tensor(
                [[[0.1, 0.2, 0.3], [0.04, 0.05, 0.06], [0.0, 0.0, 0.0]],
                 [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.0, 0.0, 0.0]]],
                dtype=torch.float32,
                device=device,
                requires_grad=True)

            supervision_segments = torch.tensor([[0, 0, 2], [1, 0, 3]],
                                                dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
            out_fsa = k2.intersect_dense_pruned(fsa_vec,
                                                dense_fsa_vec,
                                                search_beam=100000,
                                                output_beam=100000,
                                                min_active_states=0,
                                                max_active_states=10000,
                                                seqframe_idx_name='seqframe',
                                                frame_idx_name='frame')
            assert torch.all(
                torch.eq(out_fsa.seqframe,
                         torch.tensor([0, 1, 2, 3, 4, 5, 6], device=device)))

            assert torch.all(
                torch.eq(out_fsa.frame,
                         torch.tensor([0, 1, 2, 0, 1, 2, 3], device=device)))

            assert out_fsa.shape == (2, None,
                                     None), 'There should be two FSAs!'

            scores = out_fsa.get_tot_scores(log_semiring=False,
                                            use_double_scores=False)
            scores.sum().backward()

            # `expected` results are computed using gtn.
            # See https://bit.ly/3oYObeb
            expected_scores_out_fsa = torch.tensor(
                [1.2, 2.06, 3.0, 1.2, 50.5, 2.0, 3.0], device=device)
            expected_grad_fsa = torch.tensor([2.0, 1.0, 2.0, 2.0],
                                             device=device)
            expected_grad_log_prob = torch.tensor([
                0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0, 0, 0, 0.0, 1.0, 0.0, 0.0, 1.0,
                0.0, 0.0, 0.0, 1.0
            ]).reshape_as(log_prob).to(device)

            assert torch.allclose(out_fsa.scores, expected_scores_out_fsa)
            assert torch.allclose(expected_grad_fsa, fsa.scores.grad)
            assert torch.allclose(expected_grad_log_prob, log_prob.grad)
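
Where the expected `seqframe` and `frame` values come from (an observation about this test's data, not a k2 API): each supervision of n frames contributes n + 1 positions to the dense FSA, the extra one scoring the final arc with label -1. With the huge beams above nothing is pruned, so:

import torch

num_frames = [2, 3]                      # supervision_segments[:, 2]
positions = [n + 1 for n in num_frames]  # [3, 4]: +1 for the final frame

frame = torch.cat([torch.arange(p) for p in positions])
# tensor([0, 1, 2, 0, 1, 2, 3])
seqframe = torch.arange(sum(positions))
# tensor([0, 1, 2, 3, 4, 5, 6])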
Example #4
    def test_two_fsas_long_pruned(self):
        # as test_two_fsas_long in intersect_dense_test.py,
        # but with pruned intersection
        s1 = '''
            0 1 1 1.0
            1 1 1 50.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''

        s2 = '''
            0 1 1 1.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''

        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda', 0))
        for device in devices:
            fsa1 = k2.Fsa.from_str(s1)
            fsa2 = k2.Fsa.from_str(s2)

            fsa1.requires_grad_(True)
            fsa2.requires_grad_(True)

            fsa_vec = k2.create_fsa_vec([fsa1, fsa2])
            log_prob = torch.rand((2, 100, 3),
                                  dtype=torch.float32,
                                  device=device,
                                  requires_grad=True)

            supervision_segments = torch.tensor([[0, 1, 95], [1, 20, 50]],
                                                dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
            fsa_vec = fsa_vec.to(device)
            out_fsa = k2.intersect_dense_pruned(fsa_vec,
                                                dense_fsa_vec,
                                                search_beam=100,
                                                output_beam=100,
                                                min_active_states=1,
                                                max_active_states=10,
                                                seqframe_idx_name='seqframe',
                                                frame_idx_name='frame')

            expected_seqframe = torch.arange(96).to(torch.int32).to(device)
            assert torch.allclose(out_fsa.seqframe, expected_seqframe)

            # the second output FSA is empty since there is no self-loop in fsa2,
            # so every arc belongs to the first sequence and `frame` equals `seqframe`
            assert torch.allclose(out_fsa.frame, expected_seqframe)

            assert out_fsa.shape == (2, None,
                                     None), 'There should be two FSAs!'

            scores = out_fsa.get_tot_scores(log_semiring=False,
                                            use_double_scores=False)
            scores.sum().backward()
Example #5
def decode(dataloader: torch.utils.data.DataLoader,
           model: torch.nn.Module,
           device: Union[str, torch.device],
           ctc_topo: k2.Fsa,
           numericalizer=None,
           num_paths=-1,
           output_beam_size: float = 8):
    tot_num_cuts = len(dataloader.dataset.cuts)
    num_cuts = 0
    results = []
    for batch_idx, batch in enumerate(dataloader):
        assert isinstance(batch, dict), type(batch)
        feature = batch['inputs']
        supervisions = batch['supervisions']
        supervision_segments = torch.stack(
            (supervisions['sequence_idx'],
             (((supervisions['start_frame'] - 1) // 2 - 1) // 2),
             (((supervisions['num_frames'] - 1) // 2 - 1) // 2)), 1).to(torch.int32)
        supervision_segments = torch.clamp(supervision_segments, min=0)
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]
        texts = supervisions['text']
        assert feature.ndim == 3

        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
        nnet_output = nnet_output.permute(0, 2, 1)

        # TODO(Liyong Guo): Tune this bias
        # blank_bias = 0.0
        # nnet_output[:, :, 0] += blank_bias

        with torch.no_grad():
            dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

            lattices = k2.intersect_dense_pruned(ctc_topo, dense_fsa_vec, 20.0,
                                                 output_beam_size, 30, 10000)

        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        hyps = get_texts(best_paths, indices)
        assert len(hyps) == len(texts)

        for i in range(len(texts)):
            pieces = [numericalizer.tokens_list[token_id] for token_id in hyps[i]]
            hyp_words = numericalizer.tokenizer.DecodePieces(pieces).split(' ')
            ref_words = texts[i].split(' ')
            results.append((ref_words, hyp_words))

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, cuts processed so far: {}/{} ({:.6f}%)'.format(
                    batch_idx, num_cuts, tot_num_cuts,
                    float(num_cuts) / tot_num_cuts * 100))
        num_cuts += len(texts)
    return results
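
The ((x - 1) // 2 - 1) // 2 arithmetic above mirrors two stride-2 subsampling layers, each mapping a length L to (L - 1) // 2 (as a kernel-3, stride-2 convolution without padding does), and torch.clamp(min=0) guards the short-segment corner cases. A quick check of the formula on a few illustrative lengths:

def subsample(length: int) -> int:
    return ((length - 1) // 2 - 1) // 2

for length in (8, 100, 101, 103):
    print(length, '->', subsample(length))
# 8 -> 1, 100 -> 24, 101 -> 24, 103 -> 25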
Example #6
def get_objf(batch, model, device, L, symbols, training, optimizer=None):
    feature = batch['features']
    supervisions = batch['supervisions']
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'], supervisions['start_frame'],
         supervisions['num_frames']), 1).to(torch.int32)
    texts = supervisions['text']
    assert feature.ndim == 3
    #print(feature.shape)
    #print(supervision_segments[:, 1] + supervision_segments[:, 2])

    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    feature = feature.to(device)
    if training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    # TODO(haowen): create decoding graph at the beginning of training
    decoding_graph = create_decoding_graph(texts, L, symbols)
    decoding_graph = decoding_graph.to(device)
    decoding_graph.scores.requires_grad_(False)
    #print(nnet_output.shape)
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    #dense_fsa_vec.scores.requires_grad_(True)
    assert decoding_graph.is_cuda()
    assert decoding_graph.device == device
    assert nnet_output.device == device
    #print(nnet_output.get_device())
    target_graph = k2.intersect_dense_pruned(decoding_graph,
                                             dense_fsa_vec,
                                             search_beam=10.0,
                                             output_beam=10.0,
                                             min_active_states=0,
                                             max_active_states=10000)
    tot_scores = -target_graph.get_tot_scores(log_semiring=True,
                                              use_double_scores=True).sum()
    if training:
        optimizer.zero_grad()
        tot_scores.backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    objf = tot_scores.detach().cpu()
    total_objf = objf.item()
    total_frames = nnet_output.shape[0]

    return total_objf, total_frames
Example #7
def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
           device: Union[str, torch.device], LG: Fsa, symbols: SymbolTable):
    results = []  # a list of pairs (ref_words, hyp_words)
    for batch_idx, batch in enumerate(dataloader):
        feature = batch['features']
        supervisions = batch['supervisions']
        supervision_segments = torch.stack(
            (supervisions['sequence_idx'],
             torch.floor_divide(supervisions['start_frame'],
                                model.subsampling_factor),
             torch.floor_divide(supervisions['num_frames'],
                                model.subsampling_factor)), 1).to(torch.int32)
        texts = supervisions['text']
        assert feature.ndim == 3

        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
        with torch.no_grad():
            nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2,
                                          1)  # now nnet_output is [N, T, C]

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
        assert LG.is_cuda()
        assert LG.device == nnet_output.device, \
            f"Check failed: LG.device ({LG.device}) == nnet_output.device ({nnet_output.device})"
        # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
        # thus `tot_scores` will be `inf`. Definitely we need to handle this later.
        lattices = k2.intersect_dense_pruned(LG, dense_fsa_vec, 2000.0, 20.0,
                                             30, 300)
        best_paths = k2.shortest_path(lattices, use_double_scores=False)
        best_paths = best_paths.to('cpu')
        assert best_paths.shape[0] == len(texts)

        for i in range(len(texts)):
            hyp_words = [
                symbols.get(x) for x in best_paths[i].aux_labels if x > 0
            ]
            results.append((texts[i].split(' '), hyp_words))

        if batch_idx % 10 == 0:
            logging.info('Processed batch {}/{} ({:.6f}%)'.format(
                batch_idx, len(dataloader),
                float(batch_idx) / len(dataloader) * 100))

    return results
Example #8
def _compute_mmi_loss_pruned(
        nnet_output: torch.Tensor,
        texts: List[str],
        supervision_segments: torch.Tensor,
        graph_compiler: MmiTrainingGraphCompiler,
        P: k2.Fsa,
        den_scale: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    See :func:`_compute_mmi_loss_exact_optimized` for the meaning
    of the arguments.

    `pruned` means it uses k2.intersect_dense_pruned

    Note:
      It uses the least amount of memory, but the loss is not exact due
      to pruning.
    '''
    num_graphs, den_graphs = graph_compiler.compile(texts,
                                                    P,
                                                    replicate_den=False)
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0)

    # the values for search_beam/output_beam/min_active_states/max_active_states
    # are not tuned. You may want to tune them.
    den_lats = k2.intersect_dense_pruned(den_graphs,
                                         dense_fsa_vec,
                                         search_beam=20.0,
                                         output_beam=7.0,
                                         min_active_states=30,
                                         max_active_states=10000)

    num_tot_scores = num_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    den_tot_scores = den_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    tot_scores = num_tot_scores - den_scale * den_tot_scores
    tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
        tot_scores, supervision_segments[:, 2])
    return tot_score, tot_frames, all_frames
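
A toy numeric illustration of how the MMI objective is assembled above (the score values are made up): per utterance, the contribution to the loss is -(num_tot_score - den_scale * den_tot_score).

import torch

num_tot_scores = torch.tensor([-12.3, -40.1], dtype=torch.float64)
den_tot_scores = torch.tensor([-10.9, -38.0], dtype=torch.float64)
den_scale = 1.0

tot_scores = num_tot_scores - den_scale * den_tot_scores
# tensor([-1.4000, -2.1000], dtype=torch.float64)
loss = -tot_scores.sum()  # what training would minimize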
Example #9
    def decode(
            self, cuts: Union[AnyCut,
                              CutSet]) -> List[Tuple[List[str], List[str]]]:
        """
        Perform decoding with an n-gram language model (HLG graph).
        Doesn't support rescoring at this time.
        """
        if isinstance(cuts, (Cut, MixedCut)):
            cuts = CutSet.from_cuts([cuts])
        word_results = []
        # Hacky way to get batch quickly... we may need to improve on this.
        batch = K2SpeechRecognitionDataset(cuts,
                                           input_strategy=OnTheFlyFeatures(
                                               self.extractor),
                                           check_inputs=False)[list(cuts.ids)]
        features = batch['inputs'].permute(0, 2, 1).to(
            self.device)  # (B, T, F) -> (B, F, T)
        supervision_segments, texts = encode_supervisions(
            batch['supervisions'])

        # Forward pass through the acoustic model
        posteriors, _, _ = self.model(features)
        posteriors = posteriors.permute(0, 2, 1)  # (B, F, T) -> (B, T, F)

        # Wrapping into k2 "dense FSA" (representing PPG as a dense graph)
        dense_fsa_vec = k2.DenseFsaVec(posteriors, supervision_segments)

        # The actual decoding starts here:
        # First, we intersect the HLG and the PPG
        # with default pruning/beam search params from snowfall
        # The result is a batch of graphs (lattices)
        lattices = k2.intersect_dense_pruned(self.HLG, dense_fsa_vec, 20.0, 8,
                                             30, 10000)
        # ... then we find the shortest paths in the lattices ...
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        # ... and convert them to words with a convenience wrapper from snowfall
        hyps = get_texts(best_paths, torch.arange(len(texts)))

        # Here we read out the words from the best path graphs
        for i in range(len(texts)):
            hyp_words = [self.lexicon.words.get(x) for x in hyps[i]]
            ref_words = texts[i].split(' ')
            word_results.append((ref_words, hyp_words))
        return word_results
Example #10
    def test_simple(self):
        s = '''
            0 1 1 1.0
            1 1 1 50.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''
        fsa = k2.Fsa.from_str(s)
        fsa.requires_grad_(True)
        fsa_vec = k2.create_fsa_vec([fsa])
        log_prob = torch.tensor([[[0.1, 0.2, 0.3], [0.04, 0.05, 0.06]]],
                                dtype=torch.float32,
                                requires_grad=True)

        supervision_segments = torch.tensor([[0, 0, 2]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
        out_fsa = k2.intersect_dense_pruned(fsa_vec,
                                            dense_fsa_vec,
                                            search_beam=100000,
                                            output_beam=100000,
                                            min_active_states=0,
                                            max_active_states=10000)
        scores = out_fsa.get_tot_scores(log_semiring=False,
                                        use_double_scores=False)
        scores.sum().backward()

        # `expected` results are computed using gtn.
        # See https://bit.ly/3oYObeb
        expected_scores_out_fsa = torch.tensor([1.2, 2.06, 3.0])
        expected_grad_fsa = torch.tensor([1.0, 0.0, 1.0, 1.0])
        expected_grad_log_prob = torch.tensor([0.0, 1.0, 0.0, 0.0, 0.0,
                                               1.0]).reshape_as(log_prob)
        assert torch.allclose(out_fsa.scores, expected_scores_out_fsa)
        assert torch.allclose(expected_grad_fsa, fsa.scores.grad)
        assert torch.allclose(expected_grad_log_prob, log_prob.grad)
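
Where the expected scores come from, by hand (matching the gtn reference linked above): with log_semiring=False each output arc's score is the graph arc's score plus the log-prob of its label at the corresponding frame.

arc0 = 1.0 + 0.2   # arc 0->1 (label 1) + log_prob[0, 0, 1] -> 1.2
arc1 = 2.0 + 0.06  # arc 1->2 (label 2) + log_prob[0, 1, 2] -> 2.06
arc2 = 3.0         # final arc 2->3 (label -1), scored at the appended final frame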
Example #11
    def __call__(
        self, batch: Dict[str, Union[torch.Tensor, np.ndarray]]
    ) -> List[Tuple[Optional[str], List[str], List[int], float]]:
        """Inference

        Args:
            batch: Input speech data and corresponding lengths
        Returns:
            text, token, token_int, hyp

        """
        assert check_argument_types()

        if isinstance(batch["speech"], np.ndarray):
            batch["speech"] = torch.tensor(batch["speech"])
        if isinstance(batch["speech_lengths"], np.ndarray):
            batch["speech_lengths"] = torch.tensor(batch["speech_lengths"])

        # a. To device
        batch = to_device(batch, device=self.device)

        # b. Forward Encoder
        # enc: [N, T, C]
        enc, encoder_out_lens = self.asr_model.encode(**batch)

        # logp_encoder_output: [N, T, C]
        logp_encoder_output = torch.nn.functional.log_softmax(
            self.asr_model.ctc.ctc_lo(enc), dim=2
        )

        # It may be useful to tune blank_bias.
        # The valid range of blank_bias is [-inf, 0].
        logp_encoder_output[:, :, 0] += self.blank_bias

        batch_size = encoder_out_lens.size(0)
        sequence_idx = torch.arange(0, batch_size).unsqueeze(0).t().to(torch.int32)
        start_frame = torch.zeros([batch_size], dtype=torch.int32).unsqueeze(0).t()
        num_frames = encoder_out_lens.cpu().unsqueeze(0).t().to(torch.int32)
        supervision_segments = torch.cat([sequence_idx, start_frame, num_frames], dim=1)

        supervision_segments = supervision_segments.to(torch.int32)

        # An introduction to DenseFsaVec:
        # https://k2-fsa.github.io/k2/core_concepts/index.html#dense-fsa-vector
        # It can be viewed as an FSA-style representation of logp_encoder_output,
        # whose arc weights are initialized from logp_encoder_output.
        # Converting the tensor into an FSA lets us apply FSA-related
        # functions in k2, e.g. k2.intersect_dense_pruned below.
        dense_fsa_vec = k2.DenseFsaVec(logp_encoder_output, supervision_segments)

        # The term "intersect" is similar to "compose" in k2.
        # The differences is are:
        # for "compose" functions, the composition involves
        # mathcing output label of a.fsa and input label of b.fsa
        # while for "intersect" functions, the composition involves
        # matching input label of a.fsa and input label of b.fsa
        # Actually, in compose functions, b.fsa is inverted and then
        # a.fsa and inv_b.fsa are intersected together.
        # For difference between compose and interset:
        # https://github.com/k2-fsa/k2/blob/master/k2/python/k2/fsa_algo.py#L308
        # For definition of k2.intersect_dense_pruned:
        # https://github.com/k2-fsa/k2/blob/master/k2/python/k2/autograd.py#L648
        lattices = k2.intersect_dense_pruned(
            self.decode_graph,
            dense_fsa_vec,
            self.search_beam_size,
            self.output_beam_size,
            self.min_active_states,
            self.max_active_states,
        )

        # lattices.scores is the sum of decode_graph.scores (a.k.a. the lm weight)
        # and dense_fsa_vec.scores (a.k.a. the am weight) on the related arcs.
        # For a CTC decoding graph, lattices.scores stores only the am weight,
        # since the decode_graph only defines the CTC topology and
        # carries no lm weight on its arcs.
        # For 3-gram decoding, whose graph is built from a language model,
        # lattices.scores contains both am weights and lm weights.
        #
        # It may be useful to tune lattice_weight.
        # The valid range of lattice_weight is [0, inf).
        # lattice_weight affects the search in k2.random_paths.
        lattices.scores *= self.lattice_weight

        results = []
        if self.use_nbest_rescoring:
            (
                am_scores,
                lm_scores,
                token_ids,
                new2old,
                path_to_seq_map,
                seq_to_path_splits,
            ) = nbest_am_lm_scores(
                lattices, self.num_paths, self.device, self.nbest_batch_size
            )

            ys_pad_lens = torch.tensor([len(hyp) for hyp in token_ids]).to(self.device)
            max_token_length = max(ys_pad_lens)
            ys_pad_list = []
            for hyp in token_ids:
                ys_pad_list.append(
                    torch.cat(
                        [
                            torch.tensor(hyp, dtype=torch.long),
                            torch.tensor(
                                [self.asr_model.ignore_id]
                                * (max_token_length.item() - len(hyp)),
                                dtype=torch.long,
                            ),
                        ]
                    )
                )

            ys_pad = (
                torch.stack(ys_pad_list).to(torch.long).to(self.device)
            )  # [batch, max_token_length]

            encoder_out = enc.index_select(0, path_to_seq_map.to(torch.long)).to(
                self.device
            )  # [batch, T, dim]
            encoder_out_lens = encoder_out_lens.index_select(
                0, path_to_seq_map.to(torch.long)
            ).to(
                self.device
            )  # [batch]

            decoder_scores = -self.asr_model.batchify_nll(
                encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, self.nll_batch_size
            )

            # padded_value for nnlm is 0
            ys_pad[ys_pad == self.asr_model.ignore_id] = 0
            nnlm_nll, x_lengths = self.lm.batchify_nll(
                ys_pad, ys_pad_lens, self.nll_batch_size
            )
            nnlm_scores = -nnlm_nll.sum(dim=1)

            batch_tot_scores = (
                self.am_weight * am_scores
                + self.decoder_weight * decoder_scores
                + self.nnlm_weight * nnlm_scores
            )
            split_size = indices_to_split_size(
                seq_to_path_splits.tolist(), total_elements=batch_tot_scores.size(0)
            )
            batch_tot_scores = torch.split(
                batch_tot_scores,
                split_size,
            )

            hyps = []
            scores = []
            processed_seqs = 0
            for tot_scores in batch_tot_scores:
                if tot_scores.nelement() == 0:
                    # the last element by torch.tensor_split may be empty
                    # e.g.
                    # torch.tensor_split(torch.tensor([1,2,3,4]), torch.tensor([2,4]))
                    # (tensor([1, 2]), tensor([3, 4]), tensor([], dtype=torch.int64))
                    break
                best_seq_idx = processed_seqs + torch.argmax(tot_scores)

                assert best_seq_idx < len(token_ids)
                best_token_seqs = token_ids[best_seq_idx]
                processed_seqs += tot_scores.nelement()
                hyps.append(best_token_seqs)
                scores.append(tot_scores.max().item())

            assert len(hyps) == len(split_size)
        else:
            best_paths = k2.shortest_path(lattices, use_double_scores=True)
            scores = best_paths.get_tot_scores(
                use_double_scores=True, log_semiring=False
            ).tolist()
            hyps = get_texts(best_paths)

        assert len(scores) == len(hyps)

        for token_int, score in zip(hyps, scores):
            # For the decoding methods nbest_rescoring and ctc_decoding,
            # hyps stores token indexes, i.e. lattice labels.

            # convert token_id to text with self.tokenizer
            token = self.converter.ids2tokens(token_int)
            assert self.tokenizer is not None
            text = self.tokenizer.tokens2text(token)
            results.append((text, token, token_int, score))

        assert check_return_type(results)
        return results
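
The manual padding loop above (building ys_pad_list) could be written more compactly with torch.nn.utils.rnn.pad_sequence; a sketch, assuming token_ids is a list of plain int lists and ignore_id is the padding value:

import torch
from torch.nn.utils.rnn import pad_sequence

token_ids = [[5, 9, 2], [7, 1]]  # illustrative hypotheses
ignore_id = -1                   # stand-in for asr_model.ignore_id

ys = [torch.tensor(hyp, dtype=torch.long) for hyp in token_ids]
ys_pad = pad_sequence(ys, batch_first=True, padding_value=ignore_id)
# tensor([[ 5,  9,  2],
#         [ 7,  1, -1]])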
Example #12
def decode(
    dataloader: torch.utils.data.DataLoader,
    model: AcousticModel,
    device: Union[str, torch.device],
    HLG: Fsa,
    symbols: SymbolTable,
):
    num_batches = None
    try:
        num_batches = len(dataloader)
    except TypeError:
        pass
    num_cuts = 0
    results = []  # a list of pairs (ref_words, hyp_words)
    for batch_idx, batch in enumerate(dataloader):
        feature = batch["inputs"]
        supervisions = batch["supervisions"]
        supervision_segments = torch.stack(
            (
                supervisions["sequence_idx"],
                torch.floor_divide(supervisions["start_frame"],
                                   model.subsampling_factor),
                torch.floor_divide(supervisions["num_frames"],
                                   model.subsampling_factor),
            ),
            1,
        ).to(torch.int32)
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]
        texts = supervisions["text"]
        assert feature.ndim == 3

        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
        with torch.no_grad():
            nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2,
                                          1)  # now nnet_output is [N, T, C]

        blank_bias = -3.0
        nnet_output[:, :, 0] += blank_bias

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
        # assert HLG.is_cuda()
        assert (
            HLG.device == nnet_output.device
        ), f"Check failed: HLG.device ({HLG.device}) == nnet_output.device ({nnet_output.device})"
        # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
        # thus `tot_scores` will be `inf`. Definitely we need to handle this later.
        lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, 7.0, 30,
                                             10000)

        # lattices = k2.intersect_dense(HLG, dense_fsa_vec, 10.0)
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        assert best_paths.shape[0] == len(texts)
        hyps = get_texts(best_paths, indices)
        assert len(hyps) == len(texts)

        for i in range(len(texts)):
            hyp_words = [symbols.get(x) for x in hyps[i]]
            ref_words = texts[i].split(" ")
            results.append((ref_words, hyp_words))

        if batch_idx % 10 == 0:
            batch_str = f"{batch_idx}" if num_batches is None else f"{batch_idx}/{num_batches}"
            logging.info(
                f"batch {batch_str}, number of cuts processed until now is {num_cuts}"
            )

        num_cuts += len(texts)

    return results
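
What the hard-coded blank_bias = -3.0 above does, in isolation (a sketch; the value comes from this snippet and is not a universally good setting): lowering the blank log-prob makes decoding less eager to emit blank, typically trading deletions for insertions.

import torch

nnet_output = torch.randn(2, 5, 10).log_softmax(dim=-1)  # [N, T, C], blank = 0
blank_frames_before = (nnet_output.argmax(dim=-1) == 0).sum()

nnet_output[:, :, 0] += -3.0  # blank_bias
blank_frames_after = (nnet_output.argmax(dim=-1) == 0).sum()
# blank_frames_after <= blank_frames_before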
Example #13
def get_loss(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiMbrTrainingGraphCompiler,
             is_training: bool,
             optimizer: Optional[torch.optim.Optimizer] = None,
             den_scale: float = 1.0):
    assert P.device == device
    feature = batch['features']
    supervisions = batch['supervisions']
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'],
                            model.subsampling_factor),
         torch.floor_divide(supervisions['num_frames'],
                            model.subsampling_factor)), 1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]
    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if is_training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    if is_training:
        num_graph, den_graph, decoding_graph = graph_compiler.compile(texts, P)
    else:
        with torch.no_grad():
            num_graph, den_graph, decoding_graph = graph_compiler.compile(
                texts, P)

    assert num_graph.requires_grad == is_training
    assert den_graph.requires_grad is False
    assert decoding_graph.requires_grad is False
    assert len(
        decoding_graph.shape) == 2 or decoding_graph.shape == (1, None, None)

    num_graph = num_graph.to(device)
    den_graph = den_graph.to(device)

    decoding_graph = decoding_graph.to(device)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert nnet_output.device == device

    num_lats = k2.intersect_dense(num_graph,
                                  dense_fsa_vec,
                                  10.0,
                                  seqframe_idx_name='seqframe_idx')

    mbr_lats = k2.intersect_dense_pruned(decoding_graph,
                                         dense_fsa_vec,
                                         20.0,
                                         7.0,
                                         30,
                                         10000,
                                         seqframe_idx_name='seqframe_idx')

    if True:
        # WARNING: the else branch is not working at present (the total loss is not stable)
        den_lats = k2.intersect_dense(den_graph, dense_fsa_vec, 10.0)
    else:
        # in this case, we can remove den_graph
        den_lats = mbr_lats

    num_tot_scores = num_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    den_tot_scores = den_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    if den_lats is mbr_lats:
        # Some entries in den_tot_scores may be -inf.
        # The corresponding sequences are discarded/ignored.
        finite_indexes = torch.isfinite(den_tot_scores)
        den_tot_scores = den_tot_scores[finite_indexes]
        num_tot_scores = num_tot_scores[finite_indexes]
    else:
        finite_indexes = None

    tot_scores = num_tot_scores - den_scale * den_tot_scores

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2],
                                               finite_indexes)

    num_rows = dense_fsa_vec.scores.shape[0]
    num_cols = dense_fsa_vec.scores.shape[1] - 1
    mbr_num_sparse = k2.create_sparse(rows=num_lats.seqframe_idx,
                                      cols=num_lats.phones,
                                      values=num_lats.get_arc_post(True,
                                                                   True).exp(),
                                      size=(num_rows, num_cols),
                                      min_col_index=0)

    mbr_den_sparse = k2.create_sparse(rows=mbr_lats.seqframe_idx,
                                      cols=mbr_lats.phones,
                                      values=mbr_lats.get_arc_post(True,
                                                                   True).exp(),
                                      size=(num_rows, num_cols),
                                      min_col_index=0)
    # NOTE: Due to limited support of PyTorch's autograd for sparse tensors,
    # we cannot use (mbr_num_sparse - mbr_den_sparse) here
    #
    # The following works only for torch >= 1.7.0
    mbr_loss = torch.sparse.sum(
        k2.sparse.abs((mbr_num_sparse + (-mbr_den_sparse)).coalesce()))

    mmi_loss = -tot_score

    total_loss = mmi_loss + mbr_loss

    if is_training:
        optimizer.zero_grad()
        total_loss.backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (
        mmi_loss.detach().cpu().item(),
        mbr_loss.detach().cpu().item(),
        tot_frames.cpu().item(),
        all_frames.cpu().item(),
    )
    return ans
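
A pure-PyTorch sketch of the sparse L1 term computed above: because autograd support for sparse tensors is limited, the code forms A + (-B) instead of A - B, coalesces, and sums absolute values (indices and values here are illustrative):

import torch

idx = torch.tensor([[0, 1], [2, 0]])  # rows [0, 1], cols [2, 0]
a = torch.sparse_coo_tensor(idx, torch.tensor([0.90, 0.80]), size=(3, 4))
b = torch.sparse_coo_tensor(idx, torch.tensor([0.70, 0.95]), size=(3, 4))

diff = (a + (-b)).coalesce()
l1 = diff.values().abs().sum()  # tensor(0.3500), the |A - B| mass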
Example #14
    def __call__(
        self, speech: Union[torch.Tensor, np.ndarray]
    ) -> List[Tuple[Optional[str], List[str], List[int], float]]:
        """Inference

        Args:
            data: Input speech data
        Returns:
            text, token, token_int, hyp

        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)

        # data: (Nsamples,) -> (1, Nsamples)
        speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        # lengths: (1,)
        lengths = speech.new_full([1],
                                  dtype=torch.long,
                                  fill_value=speech.size(1))
        batch = {"speech": speech, "speech_lengths": lengths}

        # a. To device
        batch = to_device(batch, device=self.device)

        # b. Forward Encoder
        # enc: [N, T, C]
        enc, _ = self.asr_model.encode(**batch)
        assert len(enc) == 1, len(enc)

        # logp_encoder_output: [N, T, C]
        logp_encoder_output = torch.nn.functional.log_softmax(
            self.asr_model.ctc.ctc_lo(enc), dim=2)

        # TODO(Liyong Guo): Support batch decoding.
        # The following statement only supports batch_size == 1
        supervision_segments = torch.tensor([[0, 0, enc.shape[1]]],
                                            dtype=torch.int32)
        indices = torch.tensor([0])

        dense_fsa_vec = k2.DenseFsaVec(logp_encoder_output,
                                       supervision_segments)

        lattices = k2.intersect_dense_pruned(self.decode_graph, dense_fsa_vec,
                                             20.0, self.output_beam_size, 30,
                                             10000)

        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        scores = best_paths.get_tot_scores(use_double_scores=True,
                                           log_semiring=False).tolist()

        hyps = get_texts(best_paths, indices)
        # TODO(Liyong Guo): Support batch decoding. For now batch_size == 1.
        assert len(scores) == 1
        assert len(scores) == len(hyps)

        results = []

        for token_int, score in zip(hyps, scores):
            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)

            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None
            results.append((text, token, token_int, score))

        assert check_return_type(results)
        return results
Example #15
    def test_two_fsas(self):
        s1 = '''
            0 1 1 1.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''

        s2 = '''
            0 1 1 1.0
            1 1 1 50.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''

        fsa1 = k2.Fsa.from_str(s1)
        fsa2 = k2.Fsa.from_str(s2)

        fsa1.requires_grad_(True)
        fsa2.requires_grad_(True)

        fsa_vec = k2.create_fsa_vec([fsa1, fsa2])

        log_prob = torch.tensor(
            [[[0.1, 0.2, 0.3], [0.04, 0.05, 0.06], [0.0, 0.0, 0.0]],
             [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.0, 0.0, 0.0]]],
            dtype=torch.float32,
            requires_grad=True)

        supervision_segments = torch.tensor([[0, 0, 2], [1, 0, 3]],
                                            dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
        out_fsa = k2.intersect_dense_pruned(fsa_vec,
                                            dense_fsa_vec,
                                            search_beam=100000,
                                            output_beam=100000,
                                            min_active_states=0,
                                            max_active_states=10000)
        assert out_fsa.shape == (2, None, None), 'There should be two FSAs!'

        scores = out_fsa.get_tot_scores(log_semiring=False,
                                        use_double_scores=False)
        scores.sum().backward()

        # `expected` results are computed using gtn.
        # See https://bit.ly/3oYObeb
        expected_scores_out_fsa = torch.tensor(
            [1.2, 2.06, 3.0, 1.2, 50.5, 2.0, 3.0])

        expected_grad_fsa1 = torch.tensor([1.0, 1.0, 1.0])
        expected_grad_fsa2 = torch.tensor([1.0, 1.0, 1.0, 1.0])
        print("fsa2 is ", fsa2.__str__())
        expected_grad_log_prob = torch.tensor([
            0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0, 0, 0, 0.0, 1.0, 0.0, 0.0, 1.0,
            0.0, 0.0, 0.0, 1.0
        ]).reshape_as(log_prob)

        assert torch.allclose(out_fsa.scores, expected_scores_out_fsa)
        assert torch.allclose(expected_grad_fsa1, fsa1.scores.grad)
        assert torch.allclose(expected_grad_fsa2, fsa2.scores.grad)
        assert torch.allclose(expected_grad_log_prob, log_prob.grad)
Example #16
    def decode(
        self,
        log_probs: torch.Tensor,
        log_probs_length: torch.Tensor,
        return_lattices: bool = False,
        return_ilabels: bool = False,
        output_aligned: bool = True,
    ) -> Union['k2.Fsa', Tuple[List[torch.Tensor], List[torch.Tensor]]]:
        if self.decoding_graph is None:
            self.decoding_graph = self.base_graph

        if self.blank != 0:
            # rearrange log_probs to put blank at the first place
            # and shift targets to emulate blank = 0
            log_probs, _ = make_blank_first(self.blank, log_probs, None)
        supervisions, order = create_supervision(log_probs_length)
        if self.decoding_graph.shape[0] > 1:
            self.decoding_graph = k2.index_fsa(self.decoding_graph, order).to(device=log_probs.device)

        if log_probs.device != self.device:
            self.to(log_probs.device)
        dense_fsa_vec = (
            prep_padded_densefsavec(log_probs, supervisions)
            if self.pad_fsavec
            else k2.DenseFsaVec(log_probs, supervisions)
        )

        if self.intersect_pruned:
            lats = k2.intersect_dense_pruned(
                a_fsas=self.decoding_graph,
                b_fsas=dense_fsa_vec,
                search_beam=self.intersect_conf.search_beam,
                output_beam=self.intersect_conf.output_beam,
                min_active_states=self.intersect_conf.min_active_states,
                max_active_states=self.intersect_conf.max_active_states,
            )
        else:
            indices = torch.zeros(dense_fsa_vec.dim0(), dtype=torch.int32, device=self.device)
            dec_graphs = (
                k2.index_fsa(self.decoding_graph, indices)
                if self.decoding_graph.shape[0] == 1
                else self.decoding_graph
            )
            lats = k2.intersect_dense(dec_graphs, dense_fsa_vec, self.intersect_conf.output_beam)
        if self.pad_fsavec:
            shift_labels_inpl([lats], -1)
        self.decoding_graph = None

        if return_lattices:
            lats = k2.index_fsa(lats, invert_permutation(order).to(device=log_probs.device))
            if self.blank != 0:
                # change only ilabels
                # suppose self.blank == self.num_classes - 1
                lats.labels = torch.where(lats.labels == 0, self.blank, lats.labels - 1)
            return lats
        else:
            shortest_path_fsas = k2.index_fsa(
                k2.shortest_path(lats, True), invert_permutation(order).to(device=log_probs.device),
            )
            shortest_paths = []
            probs = []
            # direct iterating does not work as expected
            for i in range(shortest_path_fsas.shape[0]):
                shortest_path_fsa = shortest_path_fsas[i]
                labels = (
                    shortest_path_fsa.labels[:-1].to(dtype=torch.long)
                    if return_ilabels
                    else shortest_path_fsa.aux_labels[:-1].to(dtype=torch.long)
                )
                if self.blank != 0:
                    # suppose self.blank == self.num_classes - 1
                    labels = torch.where(labels == 0, self.blank, labels - 1)
                if not return_ilabels and not output_aligned:
                    labels = labels[labels != self.blank]
                shortest_paths.append(labels[::2] if self.pad_fsavec else labels)
                probs.append(get_arc_weights(shortest_path_fsa)[:-1].to(device=log_probs.device).exp())
            return shortest_paths, probs
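
invert_permutation above is a helper from the surrounding codebase; a minimal equivalent, inferred from how it is used to undo the duration-sorted supervision order (an assumption, not the actual implementation):

import torch

def invert_permutation(perm: torch.Tensor) -> torch.Tensor:
    # inv[perm[i]] = i, so indexing with inv undoes indexing with perm
    inv = torch.empty_like(perm)
    inv[perm.to(torch.long)] = torch.arange(
        perm.numel(), dtype=perm.dtype, device=perm.device)
    return inv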
Example #17
def decode_one_batch(batch: Dict[str, Any],
                     model: AcousticModel,
                     HLG: k2.Fsa,
                     output_beam_size: float,
                     num_paths: int,
                     use_whole_lattice: bool,
                     G: Optional[k2.Fsa] = None) -> Dict[str, List[List[int]]]:
    '''
    Decode one batch and return the result in a dict. The dict has the
    following format:

        - key: It indicates the setting used for decoding. For example,
               if no rescoring is used, the key is the string `no_rescore`.
               If LM rescoring is used, the key is the string `lm_scale_xxx`,
               where `xxx` is the value of `lm_scale`. An example key is
               `lm_scale_0.7`
        - value: It contains the decoding result. `len(value)` equals the
                 batch size. `value[i]` is the decoding result for the i-th
                 utterance in the given batch.

    Args:
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      model:
        The neural network model.
      HLG:
        The decoding graph.
      output_beam_size:
        Size of the beam for pruning.
      use_whole_lattice:
        If True, `G` must not be None and it will use whole lattice for
        LM rescoring.
        If False and if `G` is not None, then `num_paths` must be positive
        and it will use n-best list for LM rescoring.
      num_paths:
        It specifies the size of `n` in n-best list decoding.
      G:
        The LM. If it is None, no rescoring is used.
        Otherwise, LM rescoring is used.
        It supports two types of LM rescoring: n-best list rescoring
        and whole lattice rescoring.
        `use_whole_lattice` specifies which type to use.

    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
    '''
    device = HLG.device
    feature = batch['inputs']
    assert feature.ndim == 3
    feature = feature.to(device)

    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]

    supervisions = batch['supervisions']

    nnet_output, _, _ = model(feature, supervisions)
    # nnet_output is [N, C, T]

    nnet_output = nnet_output.permute(0, 2, 1)
    # now nnet_output is [N, T, C]

    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         (((supervisions['start_frame'] - 1) // 2 - 1) // 2),
         (((supervisions['num_frames'] - 1) // 2 - 1) // 2)),
        1).to(torch.int32)

    supervision_segments = torch.clamp(supervision_segments, min=0)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0,
                                         output_beam_size, 30, 10000)

    if G is None:
        if num_paths > 1:
            best_paths = nbest_decoding(lattices, num_paths)
            key = f'no_rescore-{num_paths}'
        else:
            key = 'no_rescore'
            best_paths = k2.shortest_path(lattices, use_double_scores=True)
        hyps = get_texts(best_paths, indices)
        return {key: hyps}

    lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]

    if use_whole_lattice:
        best_paths_dict = rescore_with_whole_lattice(lattices, G,
                                                     lm_scale_list)
    else:
        best_paths_dict = rescore_with_n_best_list(lattices, G, num_paths,
                                                   lm_scale_list)
    # best_paths_dict is a dict
    #  - key: lm_scale_xxx, where xxx is the value of lm_scale. An example
    #         key is lm_scale_1.2
    #  - value: it is the best path obtained using the corresponding lm scale
    #           from the dict key.

    ans = dict()
    for lm_scale_str, best_paths in best_paths_dict.items():
        hyps = get_texts(best_paths, indices)
        ans[lm_scale_str] = hyps
    return ans
Example #18
    def __call__(
        self, batch: Dict[str, Union[torch.Tensor, np.ndarray]]
    ) -> List[Tuple[Optional[str], List[str], List[int], float]]:
        """Inference

        Args:
            batch: Input speech data and corresponding lengths
        Returns:
            text, token, token_int, hyp

        """
        assert check_argument_types()

        if isinstance(batch["speech"], np.ndarray):
            batch["speech"] = torch.tensor(batch["speech"])
        if isinstance(batch["speech_lengths"], np.ndarray):
            batch["speech_lengths"] = torch.tensor(batch["speech_lengths"])

        # a. To device
        batch = to_device(batch, device=self.device)

        # b. Forward Encoder
        # enc: [N, T, C]
        enc, encoder_out_lens = self.asr_model.encode(**batch)

        # logp_encoder_output: [N, T, C]
        logp_encoder_output = torch.nn.functional.log_softmax(
            self.asr_model.ctc.ctc_lo(enc), dim=2)

        batch_size = encoder_out_lens.size(0)
        sequence_idx = torch.arange(0, batch_size).unsqueeze(0).t().to(
            torch.int32)
        start_frame = torch.zeros([batch_size],
                                  dtype=torch.int32).unsqueeze(0).t()
        num_frames = encoder_out_lens.cpu().unsqueeze(0).t().to(torch.int32)
        supervision_segments = torch.cat(
            [sequence_idx, start_frame, num_frames], dim=1)

        supervision_segments = supervision_segments.to(torch.int32)

        dense_fsa_vec = k2.DenseFsaVec(logp_encoder_output,
                                       supervision_segments)

        lattices = k2.intersect_dense_pruned(self.decode_graph, dense_fsa_vec,
                                             20.0, self.output_beam_size, 30,
                                             10000)

        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        scores = best_paths.get_tot_scores(use_double_scores=True,
                                           log_semiring=False).tolist()

        hyps = get_texts(best_paths)
        assert len(scores) == len(hyps)

        results = []

        for token_int, score in zip(hyps, scores):
            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)

            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None
            results.append((text, token, token_int, score))

        assert check_return_type(results)
        return results
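
An equivalent, arguably clearer way to build the same supervision matrix as the torch.cat of transposed column vectors above (illustrative lengths; note that k2.DenseFsaVec expects segments ordered by decreasing num_frames):

import torch

encoder_out_lens = torch.tensor([95, 63, 48])  # already decreasing here
batch_size = encoder_out_lens.size(0)
supervision_segments = torch.stack(
    (
        torch.arange(batch_size, dtype=torch.int32),  # sequence_idx
        torch.zeros(batch_size, dtype=torch.int32),   # start_frame
        encoder_out_lens.to(torch.int32),             # num_frames
    ),
    dim=1,
)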
Example #19
def decode(
    dataloader: torch.utils.data.DataLoader,
    model: AcousticModel,
    device: Union[str, torch.device],
    HCLG: Fsa,
):
    tot_num_cuts = len(dataloader.dataset.cuts)
    num_cuts = 0
    results = []  # a list of pair [ref_labels, hyp_labels]
    for batch_idx, batch in enumerate(dataloader):
        feature = batch["inputs"]  # (N, T, C)
        supervisions = batch["supervisions"]

        feature = feature.to(device)

        # Since we are decoding with a k2 graph here, we need to create
        # appropriate supervisions. The segments need to be ordered by
        # decreasing length (although in our case all segments have the same length)
        supervision_segments = torch.stack(
            (
                supervisions["sequence_idx"],
                torch.floor_divide(supervisions["start_frame"],
                                   model.subsampling_factor),
                torch.floor_divide(supervisions["duration"],
                                   model.subsampling_factor),
            ),
            1,
        ).to(torch.int32)
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]

        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
        with torch.no_grad():
            nnet_output = model(feature)

        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2,
                                          1)  # now nnet_output is [N, T, C]

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
        # assert HCLG.is_cuda()
        assert (
            HCLG.device == nnet_output.device
        ), f"Check failed: HCLG.device ({HCLG.device}) == nnet_output.device ({nnet_output.device})"

        lattices = k2.intersect_dense_pruned(HCLG, dense_fsa_vec, 20.0, 7.0,
                                             30, 10000)
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        assert best_paths.shape[0] == supervisions["is_voice"].shape[0]

        # best_paths is an FsaVec, and each of its FSAs is a linear FSA
        references = supervisions["is_voice"][indices]
        for i in range(references.shape[0]):
            ref = references[i, :]
            hyp = k2.arc_sort(
                best_paths[i]).arcs_as_tensor()[:-1, 2].detach().cpu()
            assert (
                ref.shape[0] == hyp.shape[0]
            ), "reference and hypothesis have unequal number of frames, {} vs. {}".format(
                ref.shape[0], hyp.shape[0])
            results.append((supervisions["cut"][indices[i]], ref, hyp))

        if batch_idx % 10 == 0:
            logging.info(
                "batch {}, cuts processed so far: {}/{} ({:.6f}%)".format(
                    batch_idx,
                    num_cuts,
                    tot_num_cuts,
                    float(num_cuts) / tot_num_cuts * 100,
                ))

        num_cuts += supervisions["is_voice"].shape[0]

    return results
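
How the frame-level hypothesis labels are read off a linear best path above, on a toy FSA (arcs_as_tensor() rows hold [src, dst, label, score-bits]; the last row is the final arc with label -1, hence the [:-1] slice):

import k2

s = '''
    0 1 1 0.0
    1 2 0 0.0
    2 3 1 0.0
    3 4 -1 0.0
    4
'''
path = k2.arc_sort(k2.Fsa.from_str(s))
labels = path.arcs_as_tensor()[:-1, 2]  # tensor([1, 0, 1], dtype=torch.int32)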