Example #1
def test_on_the_fly_batch_feature_extraction(cut_set, extractor_type):
    import torch

    from lhotse.dataset import OnTheFlyFeatures

    extractor = OnTheFlyFeatures(extractor=extractor_type())
    # Match the sampling rate the extractor's default config expects.
    cut_set = cut_set.resample(16000)
    feats, feat_lens = extractor(cut_set)  # does not crash
    assert isinstance(feats, torch.Tensor)
    assert isinstance(feat_lens, torch.Tensor)
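The `cut_set` and `extractor_type` arguments are pytest fixtures that this snippet does not show. A minimal sketch of how they might be defined; the fixture bodies below are assumptions for illustration, not the library's own test setup:

import pytest
import torch
import torchaudio
from lhotse import CutSet, Fbank, Mfcc, Recording

@pytest.fixture
def cut_set(tmp_path):
    # Write a 1-second 440 Hz tone to disk and wrap it in a one-cut CutSet.
    wav = torch.sin(2 * torch.pi * 440 * torch.arange(8000) / 8000).unsqueeze(0)
    path = str(tmp_path / "tone.wav")
    torchaudio.save(path, wav, sample_rate=8000)
    return CutSet.from_cuts([Recording.from_file(path).to_cut()])

@pytest.fixture(params=[Fbank, Mfcc])
def extractor_type(request):
    return request.param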
Example #2
    def compute_features(self, cuts: Union[Cut, CutSet]) -> torch.Tensor:
        if isinstance(cuts, Cut):
            cuts = CutSet.from_cuts([cuts])
        assert cuts[0].sampling_rate == self.sampling_rate, \
            f'{cuts[0].sampling_rate} != {self.sampling_rate}'
        otf = OnTheFlyFeatures(self.extractor)
        # feats: (batch, seq_len, n_feats)
        feats, _ = otf(cuts)
        return feats
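A hedged usage sketch: `recognizer` stands in for whatever object defines compute_features(), since its class and constructor are not shown here.

from lhotse import load_manifest

cuts = load_manifest("cuts.jsonl.gz")  # any cut manifest at the matching sampling rate
feats = recognizer.compute_features(cuts.subset(first=4))  # `recognizer` is hypothetical
print(feats.shape)  # (4, seq_len, n_feats)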
Example #3
def test_on_the_fly_feats_return_audio(cut_set):
    import torch

    from lhotse import Fbank
    from lhotse.dataset import OnTheFlyFeatures

    extractor = OnTheFlyFeatures(extractor=Fbank(), return_audio=True)
    cut_set = cut_set.resample(16000)
    feats, feat_lens, audios, audio_lens = extractor(cut_set)
    assert isinstance(feats, torch.Tensor)
    assert isinstance(feat_lens, torch.Tensor)
    assert isinstance(audios, torch.Tensor)
    assert isinstance(audio_lens, torch.Tensor)

    # Two 3-second cuts: 300 frames at a 10 ms frame shift, 80 mel bins,
    # and 48000 samples at 16 kHz.
    assert feats.shape == (2, 300, 80)
    assert feat_lens.shape == (2, )
    assert audios.shape == (2, 48000)
    assert audio_lens.shape == (2, )
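The hard-coded shapes follow from lhotse's stock Fbank configuration; that the defaults are a 10 ms shift and 80 mel bins is an assumption worth verifying against your installed version:

from lhotse import Fbank

fbank = Fbank()
# 3 s at 16 kHz -> 48000 samples; a 10 ms frame shift gives 300 frames,
# each with 80 mel bins under the default config.
print(fbank.frame_shift)                       # 0.01
print(fbank.feature_dim(sampling_rate=16000))  # 80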
Example #4
    def align(self, cuts: Union[AnyCut, CutSet]) -> torch.Tensor:
        """
        Perform forced alignment and return a tensor that represents a batch of frame-level alignments:
        >>> alignments = torch.tensor([
        ...     [0, 0, 0, 1, 57, 57, 35, 35, 35, ...],
        ...     [...],
        ...     ...
        ... ])

        :return: an int32 tensor with shape ``(batch_size, num_frames)``.
        """
        # Extract feats
        # (batch, seq_len, num_feats)
        if isinstance(cuts, (Cut, MixedCut)):
            cuts = CutSet.from_cuts([cuts])
        assert cuts[0].sampling_rate == self.sampling_rate, \
            f'{cuts[0].sampling_rate} != {self.sampling_rate}'

        cuts = cuts.map_supervisions(self.normalize_text)

        otf = OnTheFlyFeatures(self.extractor)
        feats, _ = otf(cuts)
        # The acoustic model expects (batch, num_feats, seq_len)
        feats = feats.permute(0, 2, 1)
        texts = [' '.join(s.text for s in cut.supervisions) for cut in cuts]

        # Compute AM posteriors
        # (batch, num_phones, ~seq_len / 4)
        posteriors, _, _ = self.model(feats)
        # Note: we are using "dummy" supervisions so that the aligner also considers
        # the padding area. We can adjust that behaviour if needed by passing actual
        # supervision segments, but then we will have a ragged tensor (will need to
        # pad the alignments themselves).
        sups = self.dummy_supervisions(feats)
        # DenseFsaVec expects (batch, ~seq_len / 4, num_phones)
        posteriors_fsa = k2.DenseFsaVec(posteriors.permute(0, 2, 1), sups)

        # Intersection with ground truth transcript graphs
        num, den = self.compiler.compile(texts, self.P)
        alignment = k2.intersect_dense(num, posteriors_fsa, output_beam=10.0)
        best_path = k2.shortest_path(alignment, use_double_scores=True)

        # Retrieve sequences of phone IDs per frame
        # (batch, seq_len ~/ 4) -- dtype int32 (num phone labels)
        frame_labels = torch.stack(
            [best_path[i].labels[:-1] for i in range(best_path.shape[0])])
        return frame_labels
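A hedged usage sketch for reading the alignment back as phone symbols; `aligner` and its `lexicon.phones` symbol table are assumptions, mirroring the `lexicon.words` lookup used by decode() below:

frame_labels = aligner.align(cuts)  # (batch, num_frames), int32 phone IDs
for row in frame_labels:
    phones = [aligner.lexicon.phones.get(int(p)) for p in row]
    # Collapse consecutive repeats into a rough frame-level segmentation.
    segments = [p for i, p in enumerate(phones) if i == 0 or phones[i - 1] != p]
    print(segments)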
Example #5
    def decode(self, cuts: Union[AnyCut, CutSet]) -> List[Tuple[List[str], List[str]]]:
        """
        Perform decoding with an n-gram language model (HLG graph).
        Doesn't support rescoring at this time.
        """
        if isinstance(cuts, (Cut, MixedCut)):
            cuts = CutSet.from_cuts([cuts])
        word_results = []
        # Hacky way to get batch quickly... we may need to improve on this.
        batch = K2SpeechRecognitionDataset(
            cuts,
            input_strategy=OnTheFlyFeatures(self.extractor),
            check_inputs=False,
        )[list(cuts.ids)]
        # (B, T, F) -> (B, F, T)
        features = batch['inputs'].permute(0, 2, 1).to(self.device)
        supervision_segments, texts = encode_supervisions(batch['supervisions'])

        # Forward pass through the acoustic model
        posteriors, _, _ = self.model(features)
        posteriors = posteriors.permute(0, 2, 1)  # (B, F, T) -> (B, T, F)

        # Wrapping into k2 "dense FSA" (representing PPG as a dense graph)
        dense_fsa_vec = k2.DenseFsaVec(posteriors, supervision_segments)

        # The actual decoding starts here:
        # First, we intersect the HLG and the PPG
        # with default pruning/beam search params from snowfall
        # The result is a batch of graphs (lattices)
        lattices = k2.intersect_dense_pruned(
            self.HLG, dense_fsa_vec,
            search_beam=20.0, output_beam=8,
            min_active_states=30, max_active_states=10000,
        )
        # ... then we find the shortest paths in the lattices ...
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        # ... and convert them to words with a convenience wrapper from snowfall
        hyps = get_texts(best_paths, torch.arange(len(texts)))

        # Here we read out the words from the best path graphs
        for i in range(len(texts)):
            hyp_words = [self.lexicon.words.get(x) for x in hyps[i]]
            ref_words = texts[i].split(' ')
            word_results.append((ref_words, hyp_words))
        return word_results
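A sketch of scoring the returned (ref, hyp) pairs; `recognizer` is hypothetical, and the edit-distance helper is written out here rather than taken from any particular library:

def word_errors(ref, hyp):
    # Levenshtein distance over word lists (substitutions + insertions + deletions).
    d = [[i + j if i * j == 0 else 0 for j in range(len(hyp) + 1)]
         for i in range(len(ref) + 1)]
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            d[i][j] = min(d[i - 1][j] + 1,
                          d[i][j - 1] + 1,
                          d[i - 1][j - 1] + (ref[i - 1] != hyp[j - 1]))
    return d[-1][-1]

word_results = recognizer.decode(cuts)
errors = sum(word_errors(ref, hyp) for ref, hyp in word_results)
total = sum(len(ref) for ref, _ in word_results)
print(f'WER: {100.0 * errors / total:.2f}%')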
Example #6
    def compute_posteriors(self, cuts: Union[Cut, CutSet]) -> torch.Tensor:
        """
        Run the forward pass of the acoustic model and return a tensor representing a batch of phone posteriorgrams.
        """
        # Extract feats
        # (batch, seq_len, num_feats)
        if isinstance(cuts, Cut):
            cuts = CutSet.from_cuts([cuts])
        assert cuts[0].sampling_rate == self.sampling_rate, \
            f'{cuts[0].sampling_rate} != {self.sampling_rate}'
        otf = OnTheFlyFeatures(self.extractor)
        # feats: (batch, seq_len, n_feats)
        feats, _ = otf(cuts)
        # feats: (batch, n_feats, seq_len)
        feats = feats.permute(0, 2, 1)

        # Compute AM posteriors
        # posteriors: (batch, n_phones, ~seq_len / 4)
        posteriors, _, _ = self.model(feats)
        # returns: (batch, ~seq_len / 4, n_phones)
        return posteriors.permute(0, 2, 1)
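A hedged usage sketch: greedy frame-level phone predictions from the posteriorgram; `recognizer` again stands in for the object defining compute_posteriors():

import torch

posteriors = recognizer.compute_posteriors(cuts)    # (batch, ~seq_len / 4, n_phones)
frame_phone_ids = torch.argmax(posteriors, dim=-1)  # (batch, ~seq_len / 4)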