Example #1
def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Get hierarchical transcripts (i.e., phone level transcripts) from transcripts (i.e., word level transcripts).

    Args:
        ys: Word level transcripts.
        lexicon: Its labels are words, while its aux_labels are phones.

    Returns:
        List[Tensor]: Phone level transcripts.

    """

    if lexicon is None:
        # No lexicon available: fall back to the word-level transcripts.
        return ys
    else:
        L_inv = lexicon  # labels are words, aux_labels are phones

    n_batch = len(ys)
    indices = torch.arange(n_batch)

    # Word-level transcripts as a vector of linear FSAs.
    transcripts = k2.create_fsa_vec([k2.linear_fsa(x) for x in ys])
    # Compose with the inverted lexicon to map words to phone sequences.
    transcripts_lexicon = k2.intersect(transcripts, L_inv)
    transcripts_lexicon = k2.arc_sort(k2.connect(transcripts_lexicon))
    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
    # Keep the best-scoring pronunciation for each utterance.
    transcripts_lexicon = k2.shortest_path(transcripts_lexicon,
                                           use_double_scores=True)

    ys = get_texts(transcripts_lexicon, indices)
    ys = [torch.tensor(y) for y in ys]

    return ys
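
A minimal usage sketch. The toy lexicon and IDs below are hypothetical; the transducer string format ("src dst word_label phone_aux_label score", final arcs labelled -1) follows k2's Fsa.from_str, whose exact keyword arguments may differ across k2 versions.

import k2

# Toy inverted lexicon L_inv: word ID 1 maps to phone ID 11.
s = '''
0 1 1 11 0.0
1 2 -1 -1 0.0
2
'''
L_inv = k2.arc_sort(k2.Fsa.from_str(s, acceptor=False))

ys = [[1]]  # a single utterance whose word-level transcript is word ID 1
phone_targets = get_hierarchical_targets(ys, L_inv)
# phone_targets[0] is a 1-D tensor of phone IDs, here tensor([11])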
Example #2
def decode(dataloader: torch.utils.data.DataLoader,
           model: AcousticModel,
           device: Union[str, torch.device],
           ctc_topo: k2.Fsa,
           numericalizer=None,
           num_paths=-1,
           output_beam_size: float = 8):
    tot_num_cuts = len(dataloader.dataset.cuts)
    num_cuts = 0
    results = []
    for batch_idx, batch in enumerate(dataloader):
        assert isinstance(batch, dict), type(batch)
        feature = batch['inputs']
        supervisions = batch['supervisions']
        supervision_segments = torch.stack(
            (supervisions['sequence_idx'],
             # Map frame indices through the model's 4x subsampling
             # (two unpadded kernel-3, stride-2 convolutions).
             (((supervisions['start_frame'] - 1) // 2 - 1) // 2),
             (((supervisions['num_frames'] - 1) // 2 - 1) // 2)), 1).to(torch.int32)
        supervision_segments = torch.clamp(supervision_segments, min=0)
        # k2.DenseFsaVec requires segments sorted by decreasing duration.
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]
        texts = supervisions['text']
        assert feature.ndim == 3

        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
        # nnet_output is [N, C, T]; DenseFsaVec wants [N, T, C]
        nnet_output = nnet_output.permute(0, 2, 1)

        # TODO(Liyong Guo): Tune this bias
        # blank_bias = 0.0
        # nnet_output[:, :, 0] += blank_bias

        with torch.no_grad():
            dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

            lattices = k2.intersect_dense_pruned(ctc_topo, dense_fsa_vec, 20.0,
                                                 output_beam_size, 30, 10000)

        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        hyps = get_texts(best_paths, indices)
        assert len(hyps) == len(texts)

        for i in range(len(texts)):
            pieces = [numericalizer.tokens_list[token_id] for token_id in hyps[i]]
            hyp_words = numericalizer.tokenizer.DecodePieces(pieces).split(' ')
            ref_words = texts[i].split(' ')
            results.append((ref_words, hyp_words))

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, cuts processed until now is {}/{} ({:.6f}%)'.format(
                    batch_idx, num_cuts, tot_num_cuts,
                    float(num_cuts) / tot_num_cuts * 100))
        num_cuts += len(texts)
    return results
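
The (((x - 1) // 2 - 1) // 2) index arithmetic above matches two unpadded kernel-3, stride-2 convolutions, a common subsampling front end; the actual model architecture is not shown in this snippet, so treat that as an assumption. A quick self-contained check:

def conv1d_out_len(n: int, kernel: int = 3, stride: int = 2) -> int:
    # Output length of an unpadded 1-D convolution.
    return (n - kernel) // stride + 1

# The closed form used in decode() agrees with two stacked convolutions.
for n in range(7, 2000):
    assert conv1d_out_len(conv1d_out_len(n)) == ((n - 1) // 2 - 1) // 2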
Example #3
    def decode(
            self, cuts: Union[AnyCut,
                              CutSet]) -> List[Tuple[List[str], List[str]]]:
        """
        Perform decoding with an n-gram language model (HLG graph).
        Doesn't support rescoring at this time.
        """
        if isinstance(cuts, (Cut, MixedCut)):
            cuts = CutSet.from_cuts([cuts])
        word_results = []
        # Hacky way to get batch quickly... we may need to improve on this.
        batch = K2SpeechRecognitionDataset(cuts,
                                           input_strategy=OnTheFlyFeatures(
                                               self.extractor),
                                           check_inputs=False)[list(cuts.ids)]
        features = batch['inputs'].permute(0, 2, 1).to(
            self.device)  # (B, T, F) -> (B, F, T)
        supervision_segments, texts = encode_supervisions(
            batch['supervisions'])

        # Forward pass through the acoustic model
        posteriors, _, _ = self.model(features)
        posteriors = posteriors.permute(0, 2, 1)  # (B, F, T) -> (B, T, F)

        # Wrapping into k2 "dense FSA" (representing PPG as a dense graph)
        dense_fsa_vec = k2.DenseFsaVec(posteriors, supervision_segments)

        # The actual decoding starts here:
        # First, we intersect the HLG and the PPG
        # with default pruning/beam search params from snowfall
        # The result is a batch of graphs (lattices)
        lattices = k2.intersect_dense_pruned(self.HLG, dense_fsa_vec, 20.0, 8,
                                             30, 10000)
        # ... then we find the shortest paths in the lattices ...
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        # ... and convert them to words with a convenience wrapper from snowfall
        hyps = get_texts(best_paths, torch.arange(len(texts)))

        # Here we read out the words from the best path graphs
        for i in range(len(texts)):
            hyp_words = [self.lexicon.words.get(x) for x in hyps[i]]
            ref_words = texts[i].split(' ')
            word_results.append((ref_words, hyp_words))
        return word_results
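
A hypothetical call site; `recognizer` stands for an instance of the class this method belongs to, and the manifest path is illustrative. Note that `lexicon.words.get(x)` may return None for IDs missing from the symbol table, hence the filter.

from lhotse import CutSet

cuts = CutSet.from_file('cuts.jsonl.gz')  # illustrative path
for ref_words, hyp_words in recognizer.decode(cuts):
    print('REF:', ' '.join(ref_words))
    print('HYP:', ' '.join(w for w in hyp_words if w is not None))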
Example #4
def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Get hierarchical transcripts (i.e., phone level transcripts) from transcripts (i.e., word level transcripts).

    Args:
        ys: Word level transcripts.
        lexicon: Its labels are words, while its aux_labels are phones.

    Returns:
        List[Tensor]: Phone level transcripts.

    """

    if lexicon is None:
        # No lexicon available: fall back to the word-level transcripts.
        return ys
    else:
        L_inv = lexicon  # labels are words, aux_labels are phones

    n_batch = len(ys)
    indices = torch.arange(n_batch)
    device = L_inv.device

    transcripts = k2.create_fsa_vec(
        [k2.linear_fsa(x, device=device) for x in ys])
    transcripts_with_self_loops = k2.add_epsilon_self_loops(transcripts)

    transcripts_lexicon = k2.intersect(L_inv,
                                       transcripts_with_self_loops,
                                       treat_epsilons_specially=False)
    # Don't call invert_() above because we want to return phone IDs,
    # which is the `aux_labels` of transcripts_lexicon
    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
    transcripts_lexicon = k2.top_sort(transcripts_lexicon)

    transcripts_lexicon = k2.shortest_path(transcripts_lexicon,
                                           use_double_scores=True)

    ys = get_texts(transcripts_lexicon, indices)
    ys = [torch.tensor(y) for y in ys]

    return ys
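
Unlike Example #1, this version intersects with treat_epsilons_specially=False, which only works because the linear transcript FSAs first receive epsilon self-loops that can pair with the epsilon arcs inside L_inv. A small illustration of that transform:

import k2

fsa = k2.linear_fsa([1, 2, 3])
fsa_sl = k2.add_epsilon_self_loops(fsa)
# Every non-final state now carries an extra self-loop with label 0
# (epsilon), allowing composition to traverse epsilon arcs in L_inv.
print(fsa_sl)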
Example #5
def decode(
    dataloader: torch.utils.data.DataLoader,
    model: AcousticModel,
    device: Union[str, torch.device],
    HLG: Fsa,
    symbols: SymbolTable,
):
    num_batches = None
    try:
        num_batches = len(dataloader)
    except TypeError:
        pass
    num_cuts = 0
    results = []  # a list of pair (ref_words, hyp_words)
    for batch_idx, batch in enumerate(dataloader):
        feature = batch["inputs"]
        supervisions = batch["supervisions"]
        supervision_segments = torch.stack(
            (
                supervisions["sequence_idx"],
                torch.floor_divide(supervisions["start_frame"],
                                   model.subsampling_factor),
                torch.floor_divide(supervisions["num_frames"],
                                   model.subsampling_factor),
            ),
            1,
        ).to(torch.int32)
        # k2.DenseFsaVec requires segments sorted by decreasing duration.
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]
        texts = supervisions["text"]
        assert feature.ndim == 3

        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
        with torch.no_grad():
            nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

        # Penalize the blank symbol (index 0); an empirically tuned constant.
        blank_bias = -3.0
        nnet_output[:, :, 0] += blank_bias

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
        # assert HLG.is_cuda()
        assert (
            HLG.device == nnet_output.device
        ), f"Check failed: HLG.device ({HLG.device}) == nnet_output.device ({nnet_output.device})"
        # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
        # thus `tot_scores` will be `inf`. Definitely we need to handle this later.
        lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, 7.0, 30,
                                             10000)

        # lattices = k2.intersect_dense(HLG, dense_fsa_vec, 10.0)
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        assert best_paths.shape[0] == len(texts)
        hyps = get_texts(best_paths, indices)
        assert len(hyps) == len(texts)

        for i in range(len(texts)):
            hyp_words = [symbols.get(x) for x in hyps[i]]
            ref_words = texts[i].split(" ")
            results.append((ref_words, hyp_words))

        if batch_idx % 10 == 0:
            batch_str = f"{batch_idx}" if num_batches is None else f"{batch_idx}/{num_batches}"
            logging.info(
                f"batch {batch_str}, number of cuts processed until now is {num_cuts}"
            )

        num_cuts += len(texts)

    return results
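
A minimal way to score the (ref_words, hyp_words) pairs that decode() returns; this is illustrative only, real recipes use dedicated scoring tools.

# results = decode(dataloader, model, device, HLG, symbols)

def edit_distance(ref, hyp):
    # Word-level Levenshtein distance via dynamic programming.
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, 1):
            cur[j] = min(prev[j] + 1,             # deletion
                         cur[j - 1] + 1,          # insertion
                         prev[j - 1] + (r != h))  # substitution
        prev = cur
    return prev[-1]

errors = sum(edit_distance(ref, hyp) for ref, hyp in results)
words = sum(len(ref) for ref, _ in results)
print(f'WER: {errors / words:.2%}')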
def decode_one_batch(batch: Dict[str, Any],
                     model: AcousticModel,
                     HLG: k2.Fsa,
                     output_beam_size: float,
                     num_paths: int,
                     use_whole_lattice: bool,
                     G: Optional[k2.Fsa] = None) -> Dict[str, List[List[int]]]:
    '''
    Decode one batch and return the result in a dict. The dict has the
    following format:

        - key: It indicates the setting used for decoding. For example,
               if no rescoring is used, the key is the string `no_rescore`.
               If LM rescoring is used, the key is the string `lm_scale_xxx`,
               where `xxx` is the value of `lm_scale`. An example key is
               `lm_scale_0.7`
        - value: It contains the decoding result. `len(value)` equals the
                 batch size. `value[i]` is the decoding result for the i-th
                 utterance in the given batch.

    Args:
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      model:
        The neural network model.
      HLG:
        The decoding graph.
      output_beam_size:
        Size of the beam for pruning.
      num_paths:
        It specifies the size of `n` in n-best list decoding.
      use_whole_lattice:
        If True, `G` must not be None and the whole lattice is used for
        LM rescoring.
        If False and `G` is not None, then `num_paths` must be positive
        and an n-best list is used for LM rescoring.
      G:
        The LM. If it is None, no rescoring is used.
        Otherwise, LM rescoring is used.
        It supports two types of LM rescoring: n-best list rescoring
        and whole lattice rescoring.
        `use_whole_lattice` specifies which type to use.

    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
    '''
    device = HLG.device
    feature = batch['inputs']
    assert feature.ndim == 3
    feature = feature.to(device)

    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]

    supervisions = batch['supervisions']

    nnet_output, _, _ = model(feature, supervisions)
    # nnet_output is [N, C, T]

    nnet_output = nnet_output.permute(0, 2, 1)
    # now nnet_output is [N, T, C]

    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         # Map frame indices through the model's 4x subsampling
         # (two unpadded kernel-3, stride-2 convolutions).
         (((supervisions['start_frame'] - 1) // 2 - 1) // 2),
         (((supervisions['num_frames'] - 1) // 2 - 1) // 2)),
        1).to(torch.int32)

    supervision_segments = torch.clamp(supervision_segments, min=0)
    # k2.DenseFsaVec requires segments sorted by decreasing duration.
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0,
                                         output_beam_size, 30, 10000)

    if G is None:
        if num_paths > 1:
            best_paths = nbest_decoding(lattices, num_paths)
            key = f'no_rescore-{num_paths}'
        else:
            key = 'no_rescore'
            best_paths = k2.shortest_path(lattices, use_double_scores=True)
        hyps = get_texts(best_paths, indices)
        return {key: hyps}

    lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]

    if use_whole_lattice:
        best_paths_dict = rescore_with_whole_lattice(lattices, G,
                                                     lm_scale_list)
    else:
        best_paths_dict = rescore_with_n_best_list(lattices, G, num_paths,
                                                   lm_scale_list)
    # best_paths_dict is a dict
    #  - key: lm_scale_xxx, where xxx is the value of lm_scale. An example
    #         key is lm_scale_1.2
    #  - value: it is the best path obtained using the corresponding lm scale
    #           from the dict key.

    ans = dict()
    for lm_scale_str, best_paths in best_paths_dict.items():
        hyps = get_texts(best_paths, indices)
        ans[lm_scale_str] = hyps
    return ans
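
A hypothetical call site showing how the returned dict is consumed; batch, model, HLG, and G are assumed to come from the surrounding recipe.

hyps_dict = decode_one_batch(batch=batch,
                             model=model,
                             HLG=HLG,
                             output_beam_size=8.0,
                             num_paths=100,
                             use_whole_lattice=True,
                             G=G)
for key, hyps in hyps_dict.items():
    # key is 'no_rescore' or e.g. 'lm_scale_1.2'; hyps[i] is the list of
    # word IDs decoded for the i-th utterance in the batch.
    print(key, hyps[0])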