Ejemplo n.º 1
0
    def test_single_fsa(self):
        # Checks 1-best decoding and its gradient on an FsaVec that holds a
        # single FSA. The text below lists 12 arcs (one per line, numbered
        # from 0 in order of appearance) followed by the final state.
        fsa_text = '''
            0 4 1 1
            0 1 1 1
            1 2 1 2
            1 3 1 3
            2 7 1 4
            3 7 1 5
            4 6 1 2
            4 8 1 3
            5 9 -1 4
            6 9 -1 3
            7 9 -1 5
            8 9 -1 6
            9
        '''
        # Arcs expected on the best path; only they should get a gradient.
        best_arc_ids = torch.tensor([1, 3, 5, 10])
        for device in self.devices:
            fsa_vec = k2.create_fsa_vec([k2.Fsa.from_str(fsa_text).to(device)])
            fsa_vec.requires_grad_(True)
            one_best = k2.shortest_path(fsa_vec, use_double_scores=False)

            # Sum the arc scores of the best path ourselves so the result
            # can be back-propagated to `fsa_vec.scores`.
            path_score = one_best.scores.sum()
            assert path_score == 14

            path_score.backward()
            want_grad = torch.zeros(12)
            want_grad[best_arc_ids] = 1
            assert torch.allclose(fsa_vec.scores.grad, want_grad.to(device))
Ejemplo n.º 2
0
def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Turn word level transcripts into phone level transcripts.

    Args:
        ys: Word level transcripts.
        lexicon: Its labels are words, while its aux_labels are phones.
            When it is None, `ys` is returned untouched.

    Returns:
        List[Tensor]: Phone level transcripts (or `ys` itself when
        `lexicon` is None).

    """
    # Nothing to expand without a lexicon.
    if lexicon is None:
        return ys

    # One linear FSA per word sequence, batched into a single FsaVec.
    word_fsas = k2.create_fsa_vec([k2.linear_fsa(words) for words in ys])

    # Intersect with the lexicon, clean the result up and keep the best path
    # of every transcript.
    phone_fsas = k2.intersect(word_fsas, lexicon)
    phone_fsas = k2.connect(phone_fsas)
    phone_fsas = k2.arc_sort(phone_fsas)
    phone_fsas = k2.remove_epsilon(phone_fsas)
    best_paths = k2.shortest_path(phone_fsas, use_double_scores=True)

    # `get_texts` receives the identity ordering 0..len(ys)-1 as indices.
    order = torch.tensor(range(len(ys)))
    return [torch.tensor(phones) for phones in get_texts(best_paths, order)]
Ejemplo n.º 3
0
def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
           device: Union[str, torch.device], HLG: Fsa, symbols: SymbolTable):
    """Decode every batch of `dataloader` with `HLG` and pair refs with hyps.

    Args:
        dataloader: Yields dict batches holding 'inputs' (a 3-D feature
            tensor, [N, T, C]) and 'supervisions' (with 'sequence_idx',
            'start_frame', 'num_frames' and 'text').
            `dataloader.dataset.cuts` is only used to report progress.
        model: Acoustic model. It receives [N, C, T] features and must
            expose `subsampling_factor`.
        device: Device the features are moved to before the forward pass.
        HLG: Decoding graph; asserted to be on the device of the model
            output.
        symbols: Maps the IDs returned by `get_texts` to words.

    Returns:
        A list of `(ref_words, hyp_words)` pairs, one per entry of
        `supervisions['text']`, accumulated over all batches.
    """
    tot_num_cuts = len(dataloader.dataset.cuts)
    num_cuts = 0
    results = []  # a list of pair (ref_words, hyp_words)
    for batch_idx, batch in enumerate(dataloader):
        feature = batch['inputs']
        supervisions = batch['supervisions']
        # Rows are (sequence_idx, start_frame, num_frames); the frame values
        # are divided by the subsampling factor of the model.
        supervision_segments = torch.stack(
            (supervisions['sequence_idx'],
             torch.floor_divide(supervisions['start_frame'],
                                model.subsampling_factor),
             torch.floor_divide(supervisions['num_frames'],
                                model.subsampling_factor)), 1).to(torch.int32)
        # Sort the segments by decreasing number of frames; `indices` is
        # passed to `get_texts` below together with the best paths.
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]
        texts = supervisions['text']
        assert feature.ndim == 3

        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
        with torch.no_grad():
            nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2,
                                          1)  # now nnet_output is [N, T, C]

        #  blank_bias = -3.0
        #  nnet_output[:, :, 0] += blank_bias

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
        assert HLG.device == nnet_output.device, \
            f"Check failed: HLG.device ({HLG.device}) == nnet_output.device ({nnet_output.device})"
        # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
        # thus `tot_scores` will be `inf`. Definitely we need to handle this later.
        lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, 7.0, 30,
                                             10000)

        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        assert best_paths.shape[0] == len(texts)
        hyps = get_texts(best_paths, indices)
        assert len(hyps) == len(texts)

        for text, hyp in zip(texts, hyps):
            hyp_words = [symbols.get(x) for x in hyp]
            ref_words = text.split(' ')
            results.append((ref_words, hyp_words))

        # Count this batch before reporting, so the message reflects the
        # cuts that have actually been decoded so far.
        num_cuts += len(texts)

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, cuts processed until now is {}/{} ({:.6f}%)'.format(
                    batch_idx, num_cuts, tot_num_cuts,
                    float(num_cuts) / tot_num_cuts * 100))

    return results
Ejemplo n.º 4
0
def decode(dataloader: torch.utils.data.DataLoader,
           model: torch.nn.Module,
           device: Union[str, torch.device],
           ctc_topo: 'k2.Fsa',
           numericalizer=None,
           num_paths=-1,
           output_beam_size: float=8):
    """Decode every batch of `dataloader` with `ctc_topo` and pair refs with hyps.

    Args:
        dataloader: Yields dict batches holding 'inputs' (a 3-D feature
            tensor, [N, T, C]) and 'supervisions' (with 'sequence_idx',
            'start_frame', 'num_frames' and 'text').
            `dataloader.dataset.cuts` is only used to report progress.
        model: Called as `model(feature, supervisions)` with [N, C, T]
            features; the first of its three outputs is decoded.
        device: Device the features are moved to before the forward pass.
        ctc_topo: Graph intersected with the network output.
        numericalizer: Provides `tokens_list` (token id -> piece) and
            `tokenizer.DecodePieces`. Although it defaults to None, it is
            required as soon as there is a batch to decode.
        num_paths: Not used by this function.
        output_beam_size: Passed to `k2.intersect_dense_pruned` as its
            output beam.

    Returns:
        A list of `(ref_words, hyp_words)` pairs, one per entry of
        `supervisions['text']`, accumulated over all batches.
    """
    tot_num_cuts = len(dataloader.dataset.cuts)
    num_cuts = 0
    results = []
    for batch_idx, batch in enumerate(dataloader):
        assert isinstance(batch, dict), type(batch)
        feature = batch['inputs']
        supervisions = batch['supervisions']
        # Rows are (sequence_idx, start_frame, num_frames); the frame values
        # are mapped with ((x - 1) // 2 - 1) // 2 and clamped at 0 below.
        supervision_segments = torch.stack(
            (supervisions['sequence_idx'],
             (((supervisions['start_frame'] - 1) // 2 - 1) // 2),
             (((supervisions['num_frames'] - 1) // 2 - 1) // 2)), 1).to(torch.int32)
        supervision_segments = torch.clamp(supervision_segments, min=0)
        # Sort the segments by decreasing number of frames; `indices` is
        # passed to `get_texts` below together with the best paths.
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]
        texts = supervisions['text']
        assert feature.ndim == 3

        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]

        # Decoding never back-propagates, so the forward pass is run without
        # autograd as well; otherwise a graph is recorded for every batch.
        with torch.no_grad():
            nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
            nnet_output = nnet_output.permute(0, 2, 1)

            # TODO(Liyong Guo): Tune this bias
            # blank_bias = 0.0
            # nnet_output[:, :, 0] += blank_bias

            dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

            lattices = k2.intersect_dense_pruned(ctc_topo, dense_fsa_vec, 20.0,
                                                 output_beam_size, 30, 10000)

        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        hyps = get_texts(best_paths, indices)
        assert len(hyps) == len(texts)

        for text, token_ids in zip(texts, hyps):
            pieces = [numericalizer.tokens_list[token_id] for token_id in token_ids]
            hyp_words = numericalizer.tokenizer.DecodePieces(pieces).split(' ')
            ref_words = text.split(' ')
            results.append((ref_words, hyp_words))

        # Count this batch before reporting, so the message reflects the
        # cuts that have actually been decoded so far.
        num_cuts += len(texts)

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, cuts processed until now is {}/{} ({:.6f}%)'.format(
                    batch_idx, num_cuts, tot_num_cuts,
                    float(num_cuts) / tot_num_cuts * 100))
    return results
Ejemplo n.º 5
0
def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
           device: Union[str, torch.device], LG: Fsa, symbols: SymbolTable):
    """Decode all batches of `dataloader` against the graph `LG`.

    For each batch the network output is intersected with `LG`, the best
    path of every lattice is extracted and its positive `aux_labels` are
    mapped to words through `symbols`.

    Returns:
        A list of `(ref_words, hyp_words)` pairs, appended batch by batch.
    """
    pairs = []  # (ref_words, hyp_words) for every decoded text
    for batch_idx, batch in enumerate(dataloader):
        feature = batch['features']
        supervisions = batch['supervisions']
        subsampling = model.subsampling_factor
        # Rows are (sequence_idx, start_frame, num_frames); the frame values
        # are divided by the subsampling factor of the model.
        segment_columns = (
            supervisions['sequence_idx'],
            torch.floor_divide(supervisions['start_frame'], subsampling),
            torch.floor_divide(supervisions['num_frames'], subsampling),
        )
        supervision_segments = torch.stack(segment_columns,
                                           1).to(torch.int32)
        texts = supervisions['text']
        assert feature.ndim == 3

        # [N, T, C] on entry, [N, C, T] for the model.
        feature = feature.to(device).permute(0, 2, 1)
        with torch.no_grad():
            nnet_output = model(feature)
        # Back to [N, T, C] for k2.
        nnet_output = nnet_output.permute(0, 2, 1)

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
        assert LG.is_cuda()
        assert LG.device == nnet_output.device, \
            f"Check failed: LG.device ({LG.device}) == nnet_output.device ({nnet_output.device})"
        # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
        # thus `tot_scores` will be `inf`. Definitely we need to handle this later.
        lattices = k2.intersect_dense_pruned(LG, dense_fsa_vec, 2000.0, 20.0,
                                             30, 300)
        best_paths = k2.shortest_path(lattices,
                                      use_float_scores=True).to('cpu')
        assert best_paths.shape[0] == len(texts)

        for i, text in enumerate(texts):
            # Only strictly positive aux_labels are looked up in `symbols`.
            hyp_words = []
            for x in best_paths[i].aux_labels:
                if x > 0:
                    hyp_words.append(symbols.get(x))
            pairs.append((text.split(' '), hyp_words))

        if batch_idx % 10 == 0:
            num_batches = len(dataloader)
            logging.info('Processed batch {}/{} ({:.6f}%)'.format(
                batch_idx, num_batches,
                float(batch_idx) / num_batches * 100))

    return pairs
Ejemplo n.º 6
0
    def align(self, cuts: Union[AnyCut, CutSet]) -> torch.Tensor:
        """
        Force-align the given cuts and return the frame-level labels of the
        best path of each cut, stacked into one tensor, e.g.:
        >>> alignments = torch.tensor([
        ...     [0, 0, 0, 1, 57, 57, 35, 35, 35, ...],
        ...     [...],
        ...     ...
        ... ])

        :return: an int32 tensor with shape ``(batch_size, num_frames)``.
        """
        # A single cut is wrapped so that the rest handles a CutSet only.
        if isinstance(cuts, (Cut, MixedCut)):
            cuts = CutSet.from_cuts([cuts])
        assert cuts[0].sampling_rate == self.sampling_rate, \
            f'{cuts[0].sampling_rate} != {self.sampling_rate}'

        cuts = cuts.map_supervisions(self.normalize_text)

        # Features are computed on the fly and permuted before being fed to
        # the model.
        feats, _ = OnTheFlyFeatures(self.extractor)(cuts)
        feats = feats.permute(0, 2, 1)
        # One transcript per cut: the texts of its supervisions joined by
        # single spaces.
        texts = []
        for cut in cuts:
            texts.append(' '.join(s.text for s in cut.supervisions))

        # Acoustic model posteriors (first of the three model outputs).
        posteriors, _, _ = self.model(feats)
        # Note: "dummy" supervisions are used so that the aligner also
        # considers the padding area. Passing the actual supervision segments
        # would change that, but the result would then be ragged and the
        # alignments themselves would need padding.
        dummy_sups = self.dummy_supervisions(feats)
        dense_posteriors = k2.DenseFsaVec(posteriors.permute(0, 2, 1),
                                          dummy_sups)

        # Intersect with the transcript graphs; the second graph returned by
        # the compiler is not needed here.
        num_graphs, _den_graphs = self.compiler.compile(texts, self.P)
        lattice = k2.intersect_dense(num_graphs,
                                     dense_posteriors,
                                     output_beam=10.0)
        one_best = k2.shortest_path(lattice, use_double_scores=True)

        # Labels of each best path without its last entry, one row per cut.
        per_cut_labels = []
        for i in range(one_best.shape[0]):
            per_cut_labels.append(one_best[i].labels[:-1])
        return torch.stack(per_cut_labels)
Ejemplo n.º 7
0
    def decode(
            self, cuts: Union[AnyCut,
                              CutSet]) -> List[Tuple[List[str], List[str]]]:
        """
        Decode the given cuts with the HLG graph (n-gram language model) and
        return one ``(ref_words, hyp_words)`` pair per text.
        Rescoring is not supported at this time.
        """
        # A single cut is wrapped so that the rest handles a CutSet only.
        if isinstance(cuts, (Cut, MixedCut)):
            cuts = CutSet.from_cuts([cuts])

        # Hacky way to get a batch quickly: build a dataset on the fly and
        # index it with all cut ids at once. May need to be improved.
        dataset = K2SpeechRecognitionDataset(
            cuts,
            input_strategy=OnTheFlyFeatures(self.extractor),
            check_inputs=False)
        batch = dataset[list(cuts.ids)]
        # (B, T, F) -> (B, F, T)
        features = batch['inputs'].permute(0, 2, 1).to(self.device)
        supervision_segments, texts = encode_supervisions(
            batch['supervisions'])

        # Acoustic model forward pass; only its first output is used.
        posteriors, _, _ = self.model(features)
        # (B, F, T) -> (B, T, F), then wrap as a k2 "dense FSA".
        dense_fsa_vec = k2.DenseFsaVec(posteriors.permute(0, 2, 1),
                                       supervision_segments)

        # Intersect HLG with the posteriors using the default pruning/beam
        # values from snowfall; the result is a batch of lattices.
        lattices = k2.intersect_dense_pruned(self.HLG, dense_fsa_vec, 20.0, 8,
                                             30, 10000)
        # Keep the best path of every lattice and read out its IDs.
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        hyps = get_texts(best_paths, torch.arange(len(texts)))

        # Map the IDs to words and pair them with the reference words.
        word_results = []
        for i, text in enumerate(texts):
            hyp_words = [self.lexicon.words.get(x) for x in hyps[i]]
            word_results.append((text.split(' '), hyp_words))
        return word_results
Ejemplo n.º 8
0
def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Turn word level transcripts into phone level transcripts.

    Args:
        ys: Word level transcripts.
        lexicon: Its labels are words, while its aux_labels are phones.
            When it is None, `ys` is returned untouched.

    Returns:
        List[Tensor]: Phone level transcripts (or `ys` itself when
        `lexicon` is None).

    """
    # Nothing to expand without a lexicon.
    if lexicon is None:
        return ys

    # One linear FSA per word sequence, created on the lexicon's device and
    # given epsilon self-loops before the intersection.
    word_fsas = k2.create_fsa_vec(
        [k2.linear_fsa(words, device=lexicon.device) for words in ys])
    word_fsas_with_loops = k2.add_epsilon_self_loops(word_fsas)

    # No invert_() is applied to the intersection: the phone IDs wanted here
    # are its `aux_labels`.
    phone_fsas = k2.intersect(lexicon,
                              word_fsas_with_loops,
                              treat_epsilons_specially=False)
    phone_fsas = k2.top_sort(k2.remove_epsilon(phone_fsas))
    best_paths = k2.shortest_path(phone_fsas, use_double_scores=True)

    # `get_texts` receives the identity ordering 0..len(ys)-1 as indices.
    order = torch.tensor(range(len(ys)))
    return [torch.tensor(phones) for phones in get_texts(best_paths, order)]
Ejemplo n.º 9
0
Archivo: nbest.py Proyecto: k2-fsa/k2
    def intersect(self, lats: Fsa) -> 'Nbest':
        '''Intersect the paths of this Nbest with a lattice and keep, for
        every path, the 1-best path of the result.

        Caution:
          FSAs in `self.fsa` are assumed to have no epsilon self-loops.
          `self.fsa.labels` and `lats.labels` are assumed to be token IDs.

        Args:
          lats:
            An FsaVec. It can be the return value of
            :func:`whole_lattice_rescoring`.
        Returns:
          Return a new Nbest. It has the same shape as `self`, while its
          `fsa` is the 1-best path from intersecting `self.fsa` and `lats`.
        '''
        assert self.fsa.device == lats.device, \
                f'{self.fsa.device} vs {lats.device}'
        assert len(lats.shape) == 3, f'{lats.shape}'
        assert lats.arcs.dim0() == self.shape.dim0(), \
                f'{lats.arcs.dim0()} vs {self.shape.dim0()}'

        # no-op if lats is already arc sorted
        sorted_lats = k2.arc_sort(lats)
        paths_with_loops = k2.add_epsilon_self_loops(self.fsa)

        # Row ids on axis 1 of `self.shape` map every path to the lattice
        # it is intersected with.
        intersected = k2.intersect_device(a_fsas=sorted_lats,
                                          b_fsas=paths_with_loops,
                                          b_to_a_map=self.shape.row_ids(1),
                                          sorted_match_a=True)

        # 1-best path of each intersection, with its epsilons removed.
        one_best = k2.remove_epsilon(
            k2.shortest_path(intersected, use_double_scores=True))

        return Nbest(fsa=one_best, shape=self.shape)
Ejemplo n.º 10
0
    def __call__(
        self, batch: Dict[str, Union[torch.Tensor, np.ndarray]]
    ) -> List[Tuple[Optional[str], List[str], List[int], float]]:
        """Run inference on one batch.

        Args:
            batch: Input speech data and corresponding lengths, under the
                keys "speech" and "speech_lengths". numpy arrays are
                converted to tensors in place.
        Returns:
            One tuple (text, token, token_int, score) per best path; text is
            None when no tokenizer is configured.

        """
        assert check_argument_types()

        # Accept numpy input by converting it to tensors first.
        for key in ("speech", "speech_lengths"):
            if isinstance(batch[key], np.ndarray):
                batch[key] = torch.tensor(batch[key])

        # a. To device
        batch = to_device(batch, device=self.device)

        # b. Forward Encoder
        # enc: [N, T, C]
        enc, encoder_out_lens = self.asr_model.encode(**batch)

        # logp_encoder_output: [N, T, C]
        ctc_logits = self.asr_model.ctc.ctc_lo(enc)
        logp_encoder_output = torch.nn.functional.log_softmax(ctc_logits,
                                                              dim=2)

        # One row per sequence: (sequence_idx, start_frame, num_frames),
        # every sequence starting at frame 0 and spanning its whole encoder
        # output length.
        batch_size = encoder_out_lens.size(0)
        supervision_segments = torch.stack(
            [
                torch.arange(batch_size, dtype=torch.int32),
                torch.zeros([batch_size], dtype=torch.int32),
                encoder_out_lens.cpu().to(torch.int32),
            ],
            dim=1,
        ).to(torch.int32)

        dense_fsa_vec = k2.DenseFsaVec(logp_encoder_output,
                                       supervision_segments)

        lattices = k2.intersect_dense_pruned(self.decode_graph, dense_fsa_vec,
                                             20.0, self.output_beam_size, 30,
                                             10000)

        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        scores = best_paths.get_tot_scores(use_double_scores=True,
                                           log_semiring=False).tolist()

        hyps = get_texts(best_paths)
        assert len(scores) == len(hyps)

        results = []
        for token_int, score in zip(hyps, scores):
            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)
            # Text is only produced when a tokenizer is available.
            text = (self.tokenizer.tokens2text(token)
                    if self.tokenizer is not None else None)
            results.append((text, token, token_int, score))

        assert check_return_type(results)
        return results
Ejemplo n.º 11
0
def decode(
    dataloader: torch.utils.data.DataLoader,
    model: AcousticModel,
    device: Union[str, torch.device],
    HLG: Fsa,
    symbols: SymbolTable,
):
    """Decode every batch of `dataloader` with `HLG` and pair refs with hyps.

    Args:
        dataloader: Yields dict batches holding "inputs" (a 3-D feature
            tensor, [N, T, C]) and "supervisions" (with "sequence_idx",
            "start_frame", "num_frames" and "text"). It does not need to
            support `len()`; the length is only used for progress messages.
        model: Acoustic model. It receives [N, C, T] features and must
            expose `subsampling_factor`.
        device: Device the features are moved to before the forward pass.
        HLG: Decoding graph; asserted to be on the device of the model
            output.
        symbols: Maps the IDs returned by `get_texts` to words.

    Returns:
        A list of `(ref_words, hyp_words)` pairs, one per entry of
        `supervisions["text"]`, accumulated over all batches.
    """
    # Some dataloaders have no length; progress is then reported without
    # the total number of batches.
    num_batches = None
    try:
        num_batches = len(dataloader)
    except TypeError:
        pass
    num_cuts = 0
    results = []  # a list of pair (ref_words, hyp_words)
    for batch_idx, batch in enumerate(dataloader):
        feature = batch["inputs"]
        supervisions = batch["supervisions"]
        # Rows are (sequence_idx, start_frame, num_frames); the frame values
        # are divided by the subsampling factor of the model.
        supervision_segments = torch.stack(
            (
                supervisions["sequence_idx"],
                torch.floor_divide(supervisions["start_frame"],
                                   model.subsampling_factor),
                torch.floor_divide(supervisions["num_frames"],
                                   model.subsampling_factor),
            ),
            1,
        ).to(torch.int32)
        # Sort the segments by decreasing number of frames; `indices` is
        # passed to `get_texts` below together with the best paths.
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]
        texts = supervisions["text"]
        assert feature.ndim == 3

        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
        with torch.no_grad():
            nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2,
                                          1)  # now nnet_output is [N, T, C]

        # Constant offset added to the scores of output index 0.
        blank_bias = -3.0
        nnet_output[:, :, 0] += blank_bias

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
        assert (
            HLG.device == nnet_output.device
        ), f"Check failed: HLG.device ({HLG.device}) == nnet_output.device ({nnet_output.device})"
        # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
        # thus `tot_scores` will be `inf`. Definitely we need to handle this later.
        lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, 7.0, 30,
                                             10000)

        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        assert best_paths.shape[0] == len(texts)
        hyps = get_texts(best_paths, indices)
        assert len(hyps) == len(texts)

        for text, hyp in zip(texts, hyps):
            hyp_words = [symbols.get(x) for x in hyp]
            ref_words = text.split(" ")
            results.append((ref_words, hyp_words))

        # Count this batch before reporting, so the message reflects the
        # cuts that have actually been decoded so far.
        num_cuts += len(texts)

        if batch_idx % 10 == 0:
            batch_str = f"{batch_idx}" if num_batches is None else f"{batch_idx}/{num_batches}"
            logging.info(
                f"batch {batch_str}, number of cuts processed until now is {num_cuts}"
            )

    return results
Ejemplo n.º 12
0
def rescore_with_whole_lattice(lats: k2.Fsa,
                               G_with_epsilon_loops: k2.Fsa) -> k2.Fsa:
    '''Rescore a batch of lattices with an LM using the whole lattice.

    Note that `lats.scores` is overwritten: `lats.lm_scores` is subtracted
    from it.

    Args:
      lats:
        An FsaVec with an `lm_scores` attribute. It can be the output of
        `k2.intersect_dense_pruned`.
      G_with_epsilon_loops:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.
    Returns:
      The best paths of the rescored lattices.
    '''
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lats.device
    # Drop the old LM contribution so that only am_scores remain.
    lats.scores = lats.scores - lats.lm_scores

    # inverted_lats has word IDs as labels.
    # Its aux_labels are phone IDs, which is a ragged tensor k2.RaggedInt
    inverted_lats = k2.invert(lats)
    lats_with_loops = k2.add_epsilon_self_loops(inverted_lats)

    # Every lattice is intersected with the one and only FSA in G.
    b_to_a_map = torch.zeros(lats.shape[0], device=device, dtype=torch.int32)

    def intersect_with_lm(fsas):
        # Same call for the first attempt and for the retry after pruning.
        return k2.intersect_device(G_with_epsilon_loops,
                                   fsas,
                                   b_to_a_map,
                                   sorted_match_a=True)

    try:
        rescoring_lats = intersect_with_lm(lats_with_loops)
    except RuntimeError as e:
        print(f'Caught exception:\n{e}\n')
        print(f'Number of FSAs: {inverted_lats.shape[0]}')
        print('num_arcs before pruning: ',
              lats_with_loops.arcs.num_elements())

        # NOTE(fangjun): The choice of the threshold 0.001 is arbitrary here
        # to avoid OOM. We may need to fine tune it.
        inverted_lats = k2.prune_on_arc_post(inverted_lats, 0.001, True)
        lats_with_loops = k2.add_epsilon_self_loops(inverted_lats)
        print('num_arcs after pruning: ',
              lats_with_loops.arcs.num_elements())

        rescoring_lats = intersect_with_lm(lats_with_loops)

    # connect and top_sort are run on a CPU copy; the result goes back to
    # the original device.
    cpu_lats = k2.connect(rescoring_lats.to('cpu'))
    rescoring_lats = k2.top_sort(cpu_lats).to(device)

    # After inverting again, phone IDs are the labels and word IDs the
    # aux_labels.
    inverted_rescoring_lats = k2.remove_epsilon_self_loops(
        k2.invert(rescoring_lats))
    return k2.shortest_path(inverted_rescoring_lats, use_double_scores=True)
Ejemplo n.º 13
0
def rescore_with_whole_lattice(lats: k2.Fsa, G_with_epsilon_loops: k2.Fsa,
                               lm_scale_list: List[float]
                              ) -> Dict[str, k2.Fsa]:
    '''Rescore a batch of lattices with an LM using the whole lattice,
    once for every LM scale.

    Note that `lats` is modified: `lats.lm_scores` is subtracted from
    `lats.scores` and the `lm_scores` attribute is then deleted.

    Args:
      lats:
        An FsaVec with an `lm_scores` attribute. It can be the output of
        `k2.intersect_dense_pruned`.
      G_with_epsilon_loops:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.
      lm_scale_list:
        A list containing lm_scale values.
    Returns:
      A dict of FsaVec. Each key has the form `lm_scale_<value>` and its
      value holds the best decoding path for each sequence in the lattice.
    '''
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lats.device
    # Keep only am_scores: the LM scores will come from G, so the ones
    # attached to the lattice are subtracted and then removed.
    lats.scores = lats.scores - lats.lm_scores
    del lats.lm_scores
    assert hasattr(lats, 'lm_scores') is False

    # inverted_lats has word IDs as labels.
    # Its aux_labels are phone IDs, which is a ragged tensor k2.RaggedInt
    inverted_lats = k2.invert(lats)

    # Every lattice is intersected with the one and only FSA in G.
    b_to_a_map = torch.zeros(lats.shape[0], device=device, dtype=torch.int32)

    def intersect_with_lm(fsas):
        # Same call for the first attempt and for the retry after pruning.
        return k2.intersect_device(G_with_epsilon_loops,
                                   fsas,
                                   b_to_a_map,
                                   sorted_match_a=True)

    try:
        rescoring_lats = intersect_with_lm(inverted_lats)
    except RuntimeError as e:
        print(f'Caught exception:\n{e}\n')
        print(f'Number of FSAs: {inverted_lats.shape[0]}')
        print('num_arcs before pruning: ', inverted_lats.arcs.num_elements())

        # NOTE(fangjun): The choice of the threshold 0.001 is arbitrary here
        # to avoid OOM. We may need to fine tune it.
        inverted_lats = k2.prune_on_arc_post(inverted_lats, 0.001, True)
        print('num_arcs after pruning: ', inverted_lats.arcs.num_elements())

        rescoring_lats = intersect_with_lm(inverted_lats)

    # connect runs on a CPU copy; top_sort runs after moving back to the
    # original device.
    connected = k2.connect(rescoring_lats.to('cpu')).to(device)
    rescoring_lats = k2.top_sort(connected)

    # inv_lats has phone IDs as labels
    # and word IDs as aux_labels.
    inv_lats = k2.invert(rescoring_lats)

    # For every scale the scores are rebuilt from the saved ones as
    #   scores = (scores - lm_scores)/lm_scale + lm_scores
    #          = scores/lm_scale + lm_scores*(1 - 1/lm_scale)
    saved_scores = inv_lats.scores.clone()
    ans = dict()
    for lm_scale in lm_scale_list:
        scaled_am_scores = (saved_scores - inv_lats.lm_scores) / lm_scale
        inv_lats.scores = scaled_am_scores + inv_lats.lm_scores

        ans[f'lm_scale_{lm_scale}'] = k2.shortest_path(
            inv_lats, use_double_scores=True)
    return ans
def decode_one_batch(batch: Dict[str, Any],
                     model: AcousticModel,
                     HLG: k2.Fsa,
                     output_beam_size: float,
                     num_paths: int,
                     use_whole_lattice: bool,
                     G: Optional[k2.Fsa] = None) -> Dict[str, List[List[int]]]:
    '''
    Decode one batch and return the hypotheses of every decoding setting
    in a dict with the following format:

        - key: It names the setting used for decoding. Without rescoring it
               is `no_rescore` (or `no_rescore-<num_paths>` when
               `num_paths > 1`). With LM rescoring it is `lm_scale_xxx`,
               where `xxx` is the value of `lm_scale`, e.g. `lm_scale_0.7`.
        - value: It contains the decoding result. `len(value)` equals to
                 batch size. `value[i]` is the decoding result for the i-th
                 utterance in the given batch.

    Args:
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      model:
        The neural network model.
      HLG:
        The decoding graph. The features are moved to its device.
      output_beam_size:
        Size of the beam for pruning.
      num_paths:
        It specifies the size of `n` in n-best list decoding.
      use_whole_lattice:
        If True, `G` must not be None and it will use whole lattice for
        LM rescoring.
        If False and if `G` is not None, then `num_paths` must be positive
        and it will use n-best list for LM rescoring.
      G:
        The LM. If it is None, no rescoring is used.
        Otherwise, LM rescoring is used.
        It supports two types of LM rescoring: n-best list rescoring
        and whole lattice rescoring.
        `use_whole_lattice` specifies which type to use.

    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
    '''
    device = HLG.device
    feature = batch['inputs']
    assert feature.ndim == 3

    # [N, T, C] on entry, [N, C, T] for the model.
    feature = feature.to(device).permute(0, 2, 1)
    supervisions = batch['supervisions']

    # The first model output is [N, C, T]; k2 wants [N, T, C].
    nnet_output, _, _ = model(feature, supervisions)
    nnet_output = nnet_output.permute(0, 2, 1)

    # Rows are (sequence_idx, start_frame, num_frames); the frame values
    # are mapped with ((x - 1) // 2 - 1) // 2 and clamped at 0.
    start_frames = ((supervisions['start_frame'] - 1) // 2 - 1) // 2
    frame_counts = ((supervisions['num_frames'] - 1) // 2 - 1) // 2
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'], start_frames, frame_counts),
        1).to(torch.int32)
    supervision_segments = torch.clamp(supervision_segments, min=0)

    # Sort the segments by decreasing number of frames; `indices` is passed
    # to `get_texts` together with the best paths.
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0,
                                         output_beam_size, 30, 10000)

    # Without an LM there is a single result: n-best decoding when more than
    # one path is requested, plain 1-best otherwise.
    if G is None:
        if num_paths > 1:
            key = f'no_rescore-{num_paths}'
            best_paths = nbest_decoding(lattices, num_paths)
        else:
            key = 'no_rescore'
            best_paths = k2.shortest_path(lattices, use_double_scores=True)
        return {key: get_texts(best_paths, indices)}

    lm_scale_list = [
        0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0
    ]

    # best_paths_dict is a dict
    #  - key: lm_scale_xxx, where xxx is the value of lm_scale. An example
    #         key is lm_scale_1.2
    #  - value: it is the best path obtained using the corresponding lm scale
    #           from the dict key.
    if use_whole_lattice:
        best_paths_dict = rescore_with_whole_lattice(lattices, G,
                                                     lm_scale_list)
    else:
        best_paths_dict = rescore_with_n_best_list(lattices, G, num_paths,
                                                   lm_scale_list)

    return {
        lm_scale_str: get_texts(best_paths, indices)
        for lm_scale_str, best_paths in best_paths_dict.items()
    }
Ejemplo n.º 15
0
def levenshtein_alignment(
        refs: Fsa,
        hyps: Fsa,
        hyp_to_ref_map: torch.Tensor,
        sorted_match_ref: bool = False,
) -> Fsa:
    '''Compute the Levenshtein alignment between two FsaVecs.

    Works on both CPU and GPU, although it is very slow on CPU.

    Note:
      `hyps` is modified in place: its `aux_labels` attribute is renamed to
      `hyp_labels` before the intersection and is not renamed back.

    Args:
      refs:
        An FsaVec (3 axes, i.e. `len(refs.shape) == 3`) produced by
        :func:`levenshtein_graph`. Must have an `aux_labels` attribute.
      hyps:
        An FsaVec (3 axes) on the same device as `refs`, also produced by
        :func:`levenshtein_graph`. Must have an `aux_labels` attribute.
      hyp_to_ref_map:
        A 1-D torch.Tensor with dtype torch.int32 on the same device as
        `refs`. Entry `i` is the FSA-id in `refs` that FSA `i` of `hyps`
        is aligned against. It may be an identity map, all-to-zero, or
        any other mapping chosen by the caller.

        Requires
            - `hyp_to_ref_map.shape[0] == hyps.shape[0]`
            - `0 <= hyp_to_ref_map[i] < refs.shape[0]`
      sorted_match_ref:
        If true, the arcs of `refs` must be sorted by label (checked by
        the calling code via properties) and a matching approach that
        relies on this ordering is used.

    Returns:
      An FsaVec holding the alignment information and satisfying
      `ans.Dim0() == hyps.Dim0()`. It carries two extra attributes:
      `ref_labels` (the aligned sequences of refs) and `hyp_labels` (the
      aligned sequences of hyps). The Levenshtein distance is available by
      calling `get_tot_scores` on the returned FsaVec.

    Examples:
      >>> hyps = k2.levenshtein_graph([[1, 2, 3], [1, 3, 3, 2]])
      >>> refs = k2.levenshtein_graph([[1, 2, 4]])
      >>> alignment = k2.levenshtein_alignment(
              refs, hyps,
              hyp_to_ref_map=torch.tensor([0, 0], dtype=torch.int32),
              sorted_match_ref=True)
      >>> alignment.labels
      tensor([ 1,  2,  0, -1,  1,  0,  0,  0, -1], dtype=torch.int32)
      >>> alignment.ref_labels
      tensor([ 1,  2,  4, -1,  1,  2,  4,  0, -1], dtype=torch.int32)
      >>> alignment.hyp_labels
      tensor([ 1,  2,  3, -1,  1,  3,  3,  2, -1], dtype=torch.int32)
      >>> -alignment.get_tot_scores(
              use_double_scores=False, log_semiring=False)
      tensor([1., 3.])
    '''
    assert hasattr(refs, "aux_labels")
    assert hasattr(hyps, "aux_labels")

    # In-place rename on the caller's object: the hypothesis side is carried
    # through the intersection under the name `hyp_labels`.
    hyps.rename_tensor_attribute_("aux_labels", "hyp_labels")

    intersected = k2.intersect_device(
        refs,
        hyps,
        b_to_a_map=hyp_to_ref_map,
        sorted_match_a=sorted_match_ref,
    )
    loop_free = k2.remove_epsilon_self_loops(intersected)

    best_path = k2.shortest_path(loop_free, use_double_scores=True)
    alignment = best_path.invert_()

    # After the inversion, move `labels` to `ref_labels` and promote
    # `aux_labels` to be the new `labels`.
    alignment.rename_tensor_attribute_("labels", "ref_labels")
    alignment.rename_tensor_attribute_("aux_labels", "labels")

    # Subtract the offset stored on the FSA under this internal attribute.
    score_offset = getattr(
        alignment, "__ins_del_score_offset_internal_attr_")
    alignment.scores -= score_offset

    return alignment
Ejemplo n.º 16
0
    def decode(
        self,
        log_probs: torch.Tensor,
        log_probs_length: torch.Tensor,
        return_lattices: bool = False,
        return_ilabels: bool = False,
        output_aligned: bool = True,
    ) -> Union['k2.Fsa', Tuple[List[torch.Tensor], List[torch.Tensor]]]:
        """Intersect `log_probs` with the decoding graph and return the result.

        Uses `self.decoding_graph`, falling back to `self.base_graph` when it
        is None. `self.decoding_graph` is reset to None before returning.

        Args:
            log_probs: Scores handed to the dense FSA constructor
                (presumably log-probabilities shaped (batch, time, classes)
                -- TODO confirm against callers).
            log_probs_length: Passed to `create_supervision` to obtain the
                supervisions and the ordering applied to the batch.
            return_lattices: If True, return the lattices (restored to the
                input order) instead of the best paths.
            return_ilabels: If True, read `labels` from each best path;
                otherwise read `aux_labels`.
            output_aligned: If False (and `return_ilabels` is False),
                entries equal to `self.blank` are dropped from the output.

        Returns:
            The lattices when `return_lattices` is True; otherwise a tuple
            `(shortest_paths, probs)` of per-utterance label tensors and
            exponentiated arc weights.
        """
        if self.decoding_graph is None:
            self.decoding_graph = self.base_graph

        if self.blank != 0:
            # rearrange log_probs to put blank at the first place
            # and shift targets to emulate blank = 0
            log_probs, _ = make_blank_first(self.blank, log_probs, None)
        supervisions, order = create_supervision(log_probs_length)
        if self.decoding_graph.shape[0] > 1:
            # One graph per utterance: reorder graphs like the supervisions.
            reordered_graph = k2.index_fsa(self.decoding_graph, order)
            self.decoding_graph = reordered_graph.to(device=log_probs.device)

        if log_probs.device != self.device:
            self.to(log_probs.device)

        if self.pad_fsavec:
            dense_fsa_vec = prep_padded_densefsavec(log_probs, supervisions)
        else:
            dense_fsa_vec = k2.DenseFsaVec(log_probs, supervisions)

        if self.intersect_pruned:
            lats = k2.intersect_dense_pruned(
                a_fsas=self.decoding_graph,
                b_fsas=dense_fsa_vec,
                search_beam=self.intersect_conf.search_beam,
                output_beam=self.intersect_conf.output_beam,
                min_active_states=self.intersect_conf.min_active_states,
                max_active_states=self.intersect_conf.max_active_states,
            )
        else:
            # A single shared graph is replicated once per dense FSA.
            graph_ids = torch.zeros(dense_fsa_vec.dim0(), dtype=torch.int32, device=self.device)
            if self.decoding_graph.shape[0] == 1:
                dec_graphs = k2.index_fsa(self.decoding_graph, graph_ids)
            else:
                dec_graphs = self.decoding_graph
            lats = k2.intersect_dense(dec_graphs, dense_fsa_vec, self.intersect_conf.output_beam)
        if self.pad_fsavec:
            shift_labels_inpl([lats], -1)
        self.decoding_graph = None

        # Permutation that undoes the ordering from `create_supervision`.
        restore_order = invert_permutation(order).to(device=log_probs.device)

        if return_lattices:
            lats = k2.index_fsa(lats, restore_order)
            if self.blank != 0:
                # change only ilabels
                # suppose self.blank == self.num_classes - 1
                lats.labels = torch.where(lats.labels == 0, self.blank, lats.labels - 1)
            return lats

        best_paths = k2.index_fsa(k2.shortest_path(lats, True), restore_order)
        shortest_paths = []
        probs = []
        # Index explicitly: direct iteration over the FsaVec does not work
        # as expected.
        for path_idx in range(best_paths.shape[0]):
            path = best_paths[path_idx]
            raw_labels = path.labels if return_ilabels else path.aux_labels
            # The last entry (final arc) is dropped.
            labels = raw_labels[:-1].to(dtype=torch.long)
            if self.blank != 0:
                # suppose self.blank == self.num_classes - 1
                labels = torch.where(labels == 0, self.blank, labels - 1)
            if not (return_ilabels or output_aligned):
                labels = labels[labels != self.blank]
            shortest_paths.append(labels[::2] if self.pad_fsavec else labels)
            arc_probs = get_arc_weights(path)[:-1].to(device=log_probs.device).exp()
            probs.append(arc_probs)
        return shortest_paths, probs
Ejemplo n.º 17
0
    def test_fsa_vec(self):
        """Check `k2.shortest_path` gradients on an FsaVec of three FSAs.

        After back-propagating the summed best-path scores, each source FSA
        must have gradient 1 on exactly the arcs of its best path.
        """
        # best path:
        #  states: 0 -> 1 -> 3 -> 7 -> 9
        #  arcs:     1 -> 3 -> 5 -> 10
        s1 = '''
            0 4 1 1
            0 1 1 1
            1 2 1 2
            1 3 1 3
            2 7 1 4
            3 7 1 5
            4 6 1 2
            4 8 1 3
            5 9 -1 4
            6 9 -1 3
            7 9 -1 5
            8 9 -1 6
            9
        '''

        #  best path:
        #   states: 0 -> 2 -> 3 -> 4 -> 5
        #   arcs:     1 -> 4 -> 5 -> 7
        s2 = '''
            0 1 1 1
            0 2 2 6
            1 2 3 3
            1 3 4 2
            2 3 5 4
            3 4 6 3
            3 5 -1 2
            4 5 -1 0
            5
        '''

        #  best path:
        #   states: 0 -> 2 -> 3
        #   arcs:     1 -> 3
        s3 = '''
            0 1 1 10
            0 2 2 100
            1 3 -1 3.5
            2 3 -1 5.5
            3
        '''

        fsa_strs = (s1, s2, s3)
        # Expected best-path arc indexes, in the same order as `fsa_strs`.
        expected_best_arcs = ([1, 3, 5, 10], [1, 4, 5, 7], [1, 3])

        for device in self.devices:
            fsas = [k2.Fsa.from_str(text).to(device) for text in fsa_strs]
            for fsa in fsas:
                fsa.requires_grad_(True)

            fsa_vec = k2.create_fsa_vec(fsas)
            assert fsa_vec.shape == (3, None, None)

            best_path = k2.shortest_path(fsa_vec, use_double_scores=False)

            # we recompute the total_scores for backprop
            total_scores = best_path.scores.sum()
            total_scores.backward()

            for fsa, arcs in zip(fsas, expected_best_arcs):
                best_arc_indexes = torch.tensor(arcs, device=device)
                grad = fsa.scores.grad
                # Every best-path arc has gradient one ...
                assert torch.all(
                    torch.eq(grad[best_arc_indexes],
                             torch.ones(len(arcs), device=device)))
                # ... and no other arc contributes to the gradient sum.
                assert grad.sum() == len(arcs)
Ejemplo n.º 18
0
    def __call__(
        self, batch: Dict[str, Union[torch.Tensor, np.ndarray]]
    ) -> List[Tuple[Optional[str], List[str], List[int], float]]:
        """Run k2-based inference on a batch of utterances.

        Args:
            batch: Dict holding "speech" and "speech_lengths". numpy arrays
                under these keys are replaced in place by tensors; the dict
                is then moved to ``self.device`` and forwarded to
                ``self.asr_model.encode``.
        Returns:
            A list of ``(text, token, token_int, score)`` tuples, one per
            decoded hypothesis.

        """
        assert check_argument_types()

        if isinstance(batch["speech"], np.ndarray):
            batch["speech"] = torch.tensor(batch["speech"])
        if isinstance(batch["speech_lengths"], np.ndarray):
            batch["speech_lengths"] = torch.tensor(batch["speech_lengths"])

        # a. To device
        batch = to_device(batch, device=self.device)

        # b. Forward Encoder
        # enc: [N, T, C]
        enc, encoder_out_lens = self.asr_model.encode(**batch)

        # logp_encoder_output: [N, T, C]
        logp_encoder_output = torch.nn.functional.log_softmax(
            self.asr_model.ctc.ctc_lo(enc), dim=2
        )

        # It may be useful to tune blank_bias.
        # The valid range of blank_bias is [-inf, 0]
        # The bias is added in place to class index 0 (the blank).
        logp_encoder_output[:, :, 0] += self.blank_bias

        # Build one supervision row per utterance:
        # (sequence_idx, start_frame=0, num_frames), all int32.
        batch_size = encoder_out_lens.size(0)
        sequence_idx = torch.arange(0, batch_size).unsqueeze(0).t().to(torch.int32)
        start_frame = torch.zeros([batch_size], dtype=torch.int32).unsqueeze(0).t()
        num_frames = encoder_out_lens.cpu().unsqueeze(0).t().to(torch.int32)
        supervision_segments = torch.cat([sequence_idx, start_frame, num_frames], dim=1)

        supervision_segments = supervision_segments.to(torch.int32)

        # An introduction to DenseFsaVec:
        # https://k2-fsa.github.io/k2/core_concepts/index.html#dense-fsa-vector
        # It can be viewed as an FSA-type logp_encoder_output,
        # whose arc weights are initialized with logp_encoder_output.
        # The goal of converting the tensor type to the FSA type is to use
        # the FSA-related functions in k2, e.g. k2.intersect_dense_pruned below.
        dense_fsa_vec = k2.DenseFsaVec(logp_encoder_output, supervision_segments)

        # The term "intersect" is similar to "compose" in k2.
        # The differences are:
        # for "compose" functions, the composition involves
        # matching the output label of a.fsa and the input label of b.fsa,
        # while for "intersect" functions, the composition involves
        # matching the input label of a.fsa and the input label of b.fsa.
        # Actually, in compose functions, b.fsa is inverted and then
        # a.fsa and inv_b.fsa are intersected together.
        # For the difference between compose and intersect:
        # https://github.com/k2-fsa/k2/blob/master/k2/python/k2/fsa_algo.py#L308
        # For the definition of k2.intersect_dense_pruned:
        # https://github.com/k2-fsa/k2/blob/master/k2/python/k2/autograd.py#L648
        lattices = k2.intersect_dense_pruned(
            self.decode_graph,
            dense_fsa_vec,
            self.search_beam_size,
            self.output_beam_size,
            self.min_active_states,
            self.max_active_states,
        )

        # lattices.scores is the sum of decode_graph.scores (a.k.a. lm weight) and
        # dense_fsa_vec.scores (a.k.a. am weight) on the related arcs.
        # For a ctc decoding graph, lattices.scores only stores the am weight
        # since the decoder graph only defines the ctc topology and
        # has no lm weight on its arcs.
        # For 3-gram decoding, whose graph is converted from language models,
        # lattices.scores contains both am weights and lm weights.
        #
        # It may be useful to tune lattices.scores.
        # The valid range of lattice_weight is [0, inf)
        # The lattice_weight will affect the search of k2.random_paths
        lattices.scores *= self.lattice_weight

        results = []
        if self.use_nbest_rescoring:
            (
                am_scores,
                lm_scores,
                token_ids,
                new2old,
                path_to_seq_map,
                seq_to_path_splits,
            ) = nbest_am_lm_scores(
                lattices, self.num_paths, self.device, self.nbest_batch_size
            )

            # Pad every n-best token sequence to the longest one with
            # asr_model.ignore_id so they can be stacked into one tensor.
            ys_pad_lens = torch.tensor([len(hyp) for hyp in token_ids]).to(self.device)
            max_token_length = max(ys_pad_lens)
            ys_pad_list = []
            for hyp in token_ids:
                ys_pad_list.append(
                    torch.cat(
                        [
                            torch.tensor(hyp, dtype=torch.long),
                            torch.tensor(
                                [self.asr_model.ignore_id]
                                * (max_token_length.item() - len(hyp)),
                                dtype=torch.long,
                            ),
                        ]
                    )
                )

            ys_pad = (
                torch.stack(ys_pad_list).to(torch.long).to(self.device)
            )  # [batch, max_token_length]

            # Repeat each utterance's encoder output once per n-best path.
            encoder_out = enc.index_select(0, path_to_seq_map.to(torch.long)).to(
                self.device
            )  # [batch, T, dim]
            encoder_out_lens = encoder_out_lens.index_select(
                0, path_to_seq_map.to(torch.long)
            ).to(
                self.device
            )  # [batch]

            decoder_scores = -self.asr_model.batchify_nll(
                encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, self.nll_batch_size
            )

            # padded_value for nnlm is 0
            # ys_pad is overwritten in place here, after the decoder scoring
            # above has already consumed the ignore_id padding.
            ys_pad[ys_pad == self.asr_model.ignore_id] = 0
            nnlm_nll, x_lengths = self.lm.batchify_nll(
                ys_pad, ys_pad_lens, self.nll_batch_size
            )
            nnlm_scores = -nnlm_nll.sum(dim=1)

            # Weighted combination of acoustic, decoder and neural-LM scores.
            batch_tot_scores = (
                self.am_weight * am_scores
                + self.decoder_weight * decoder_scores
                + self.nnlm_weight * nnlm_scores
            )
            split_size = indices_to_split_size(
                seq_to_path_splits.tolist(), total_elements=batch_tot_scores.size(0)
            )
            batch_tot_scores = torch.split(
                batch_tot_scores,
                split_size,
            )

            hyps = []
            scores = []
            # Offset of the current utterance's first path within token_ids.
            processed_seqs = 0
            for tot_scores in batch_tot_scores:
                if tot_scores.nelement() == 0:
                    # the last element by torch.tensor_split may be empty
                    # e.g.
                    # torch.tensor_split(torch.tensor([1,2,3,4]), torch.tensor([2,4]))
                    # (tensor([1, 2]), tensor([3, 4]), tensor([], dtype=torch.int64))
                    # NOTE(review): the split above uses torch.split, not
                    # torch.tensor_split -- confirm this guard is still needed.
                    break
                best_seq_idx = processed_seqs + torch.argmax(tot_scores)

                assert best_seq_idx < len(token_ids)
                best_token_seqs = token_ids[best_seq_idx]
                processed_seqs += tot_scores.nelement()
                hyps.append(best_token_seqs)
                scores.append(tot_scores.max().item())

            assert len(hyps) == len(split_size)
        else:
            # No rescoring: take the single best path of each lattice.
            best_paths = k2.shortest_path(lattices, use_double_scores=True)
            scores = best_paths.get_tot_scores(
                use_double_scores=True, log_semiring=False
            ).tolist()
            hyps = get_texts(best_paths)

        assert len(scores) == len(hyps)

        for token_int, score in zip(hyps, scores):
            # For decoding methods nbest_rescoring and ctc_decoding
            # hyps stores token_index, which is lattice.labels.

            # convert token_id to text with self.tokenizer
            token = self.converter.ids2tokens(token_int)
            assert self.tokenizer is not None
            text = self.tokenizer.tokens2text(token)
            results.append((text, token, token_int, score))

        assert check_return_type(results)
        return results
Ejemplo n.º 19
0
def decode(
    dataloader: torch.utils.data.DataLoader,
    model: AcousticModel,
    device: Union[str, torch.device],
    HCLG: Fsa,
):
    """Decode every batch of `dataloader` with `HCLG` and collect the results.

    For each batch the model output is intersected with `HCLG`, the best
    path of every lattice is extracted, and its per-arc values (column 2 of
    the arc tensor, final arc dropped) are paired with the reference
    `is_voice` labels.

    Args:
        dataloader: Yields batches with "inputs" (N, T, C) and
            "supervisions"; `dataloader.dataset.cuts` gives the total
            number of cuts used for progress reporting.
        model: Acoustic model; called on features shaped (N, C, T) and
            providing `subsampling_factor`.
        device: Device the features are moved to before the forward pass.
        HCLG: Decoding graph; must be on the same device as the model
            output.

    Returns:
        A list of `(cut, ref, hyp)` triples, one per supervision segment.
    """
    tot_num_cuts = len(dataloader.dataset.cuts)
    num_cuts = 0
    results = []  # a list of triples (cut, ref_labels, hyp_labels)
    for batch_idx, batch in enumerate(dataloader):
        feature = batch["inputs"]  # (N, T, C)
        supervisions = batch["supervisions"]

        feature = feature.to(device)

        # Since we are decoding with a k2 graph here, we need to create appropriate
        # supervisions. The segments need to be ordered in decreasing order of
        # length (although in our case all segments are of same length)
        supervision_segments = torch.stack(
            (
                supervisions["sequence_idx"],
                torch.floor_divide(supervisions["start_frame"],
                                   model.subsampling_factor),
                torch.floor_divide(supervisions["duration"],
                                   model.subsampling_factor),
            ),
            1,
        ).to(torch.int32)
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]

        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
        with torch.no_grad():
            nnet_output = model(feature)

        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2,
                                          1)  # now nnet_output is [N, T, C]

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
        assert (
            HCLG.device == nnet_output.device
        ), f"Check failed: HCLG.device ({HCLG.device}) == nnet_output.device ({nnet_output.device})"

        lattices = k2.intersect_dense_pruned(HCLG, dense_fsa_vec, 20.0, 7.0,
                                             30, 10000)
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        assert best_paths.shape[0] == supervisions["is_voice"].shape[0]

        # best_paths is an FsaVec, and each of its FSAs is a linear FSA.
        # Reorder the references the same way the supervisions were sorted.
        references = supervisions["is_voice"][indices]
        for i in range(references.shape[0]):
            ref = references[i, :]
            # Column 2 of the arc tensor, without the final arc.
            hyp = k2.arc_sort(
                best_paths[i]).arcs_as_tensor()[:-1, 2].detach().cpu()
            assert (
                ref.shape[0] == hyp.shape[0]
            ), "reference and hypothesis have unequal number of frames, {} vs. {}".format(
                ref.shape[0], hyp.shape[0])
            results.append((supervisions["cut"][indices[i]], ref, hyp))

        # Count this batch before logging so that "processed until now"
        # includes the batch that has just been decoded.
        num_cuts += supervisions["is_voice"].shape[0]

        if batch_idx % 10 == 0:
            logging.info(
                "batch {}, cuts processed until now is {}/{} ({:.6f}%)".format(
                    batch_idx,
                    num_cuts,
                    tot_num_cuts,
                    float(num_cuts) / tot_num_cuts * 100,
                ))

    return results
Ejemplo n.º 20
0
    def __call__(
        self, speech: Union[torch.Tensor, np.ndarray]
    ) -> List[Tuple[Optional[str], List[str], List[int], float]]:
        """Decode a single utterance with the k2 decoding graph.

        Args:
            speech: Input speech data; a numpy array is converted to a
                tensor first.
        Returns:
            A list of ``(text, token, token_int, score)`` tuples. ``text``
            is None when no tokenizer is configured.

        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)

        # (Nsamples,) -> (1, Nsamples), cast to the configured dtype
        target_dtype = getattr(torch, self.dtype)
        speech = speech.unsqueeze(0).to(target_dtype)

        # a. Assemble the batch (lengths has shape (1,)) and move it to device
        batch = to_device(
            {
                "speech": speech,
                "speech_lengths": speech.new_full([1],
                                                  dtype=torch.long,
                                                  fill_value=speech.size(1)),
            },
            device=self.device,
        )

        # b. Forward Encoder
        # enc: [N, T, C]
        enc, _ = self.asr_model.encode(**batch)
        assert len(enc) == 1, len(enc)

        # logp_encoder_output: [N, T, C]
        ctc_logits = self.asr_model.ctc.ctc_lo(enc)
        logp_encoder_output = torch.nn.functional.log_softmax(ctc_logits,
                                                              dim=2)

        # TODO(Liyong Guo): Support batch decoding.
        # The single supervision row below only supports batch_size == 1.
        num_frames = enc.shape[1]
        supervision_segments = torch.tensor([[0, 0, num_frames]],
                                            dtype=torch.int32)
        indices = torch.tensor([0])

        dense_fsa_vec = k2.DenseFsaVec(logp_encoder_output,
                                       supervision_segments)
        lattices = k2.intersect_dense_pruned(self.decode_graph, dense_fsa_vec,
                                             20.0, self.output_beam_size, 30,
                                             10000)

        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        tot_scores = best_paths.get_tot_scores(use_double_scores=True,
                                               log_semiring=False)
        scores = tot_scores.tolist()

        hyps = get_texts(best_paths, indices)
        # TODO(Liyong Guo): Support batch decoding. now batch_size == 1.
        assert len(scores) == 1
        assert len(scores) == len(hyps)

        results = []
        for token_int, score in zip(hyps, scores):
            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)
            text = (None if self.tokenizer is None else
                    self.tokenizer.tokens2text(token))
            results.append((text, token, token_int, score))

        assert check_return_type(results)
        return results