def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
           device: Union[str, torch.device], HLG: Fsa, symbols: SymbolTable):
    tot_num_cuts = len(dataloader.dataset.cuts)
    num_cuts = 0
    results = []  # a list of pairs (ref_words, hyp_words)
    for batch_idx, batch in enumerate(dataloader):
        feature = batch['inputs']
        supervisions = batch['supervisions']
        supervision_segments = torch.stack(
            (supervisions['sequence_idx'],
             torch.floor_divide(supervisions['start_frame'],
                                model.subsampling_factor),
             torch.floor_divide(supervisions['num_frames'],
                                model.subsampling_factor)), 1).to(torch.int32)
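        # Each row of supervision_segments is (sequence_idx, start_frame,
        # num_frames), with the frame counts divided by the model's
        # subsampling factor so that they index into nnet_output.
        # k2.DenseFsaVec expects segments ordered by decreasing duration,
        # hence the argsort below; `indices` is kept so get_texts() can
        # restore the original utterance order.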
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]
        texts = supervisions['text']
        assert feature.ndim == 3

        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
        with torch.no_grad():
            nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2,
                                          1)  # now nnet_output is [N, T, C]

        #  blank_bias = -3.0
        #  nnet_output[:, :, 0] += blank_bias

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
        # assert HLG.is_cuda()
        assert HLG.device == nnet_output.device, \
            f"Check failed: HLG.device ({HLG.device}) == nnet_output.device ({nnet_output.device})"
        # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
        # thus `tot_scores` will be `inf`. We definitely need to handle this later.
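        # The positional arguments below are search_beam=20.0,
        # output_beam=7.0, min_active_states=30 and max_active_states=10000.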
        lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, 7.0, 30,
                                             10000)

        # lattices = k2.intersect_dense(HLG, dense_fsa_vec, 10.0)
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        assert best_paths.shape[0] == len(texts)
        hyps = get_texts(best_paths, indices)
        assert len(hyps) == len(texts)

        for i in range(len(texts)):
            hyp_words = [symbols.get(x) for x in hyps[i]]
            ref_words = texts[i].split(' ')
            results.append((ref_words, hyp_words))

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, cuts processed until now is {}/{} ({:.6f}%)'.format(
                    batch_idx, num_cuts, tot_num_cuts,
                    float(num_cuts) / tot_num_cuts * 100))

        num_cuts += len(texts)

    return results
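
A note on scoring: each (ref_words, hyp_words) pair in results can be scored with a plain word-level edit distance. A minimal sketch in pure Python; the levenshtein helper below is illustrative, not part of snowfall:

def levenshtein(ref, hyp):
    # Dynamic-programming edit distance over two word lists.
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i]
        for j, h in enumerate(hyp, 1):
            cur.append(min(prev[j] + 1,              # deletion
                           cur[j - 1] + 1,           # insertion
                           prev[j - 1] + (r != h)))  # substitution
        prev = cur
    return prev[-1]

tot_errs = sum(levenshtein(ref, hyp) for ref, hyp in results)
tot_words = sum(len(ref) for ref, _ in results)
print(f'WER: {tot_errs / tot_words:.2%}')

Example #2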
def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
           HLG: Fsa, symbols: SymbolTable, num_paths: int, G: k2.Fsa,
           use_whole_lattice: bool, output_beam_size: float):
    num_cuts = 0
    results = defaultdict(list)
    # results is a dict whose keys and values are:
    #  - key: It indicates the lm_scale, e.g., lm_scale_1.2.
    #         If no rescoring is used, the key is the literal string: no_rescore
    #
    #  - value: It is a list of tuples (ref_words, hyp_words)

    num_batches = None
    try:
        num_batches = len(dataloader)
    except TypeError:
        pass

    for batch_idx, batch in enumerate(dataloader):
        # We remove the non-tensor values under the key 'text', which enables
        # this function to run with TorchScript models.
        texts = batch['supervisions'].pop('text')

        hyps_dict = decode_one_batch(batch=batch,
                                     model=model,
                                     HLG=HLG,
                                     output_beam_size=output_beam_size,
                                     num_paths=num_paths,
                                     use_whole_lattice=use_whole_lattice,
                                     G=G)

        for lm_scale, hyps in hyps_dict.items():
            this_batch = []
            assert len(hyps) == len(texts)

            for i in range(len(texts)):
                hyp_words = [symbols.get(x) for x in hyps[i]]
                ref_words = texts[i].split(' ')
                this_batch.append((ref_words, hyp_words))

            results[lm_scale].extend(this_batch)

        num_cuts += len(texts)

        if batch_idx % 10 == 0:
            batch_str = f"{batch_idx}" if num_batches is None else f"{batch_idx}/{num_batches}"
            logging.info(
                f"batch {batch_str}, number of cuts processed until now is {num_cuts}"
            )

    return results
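
When rescoring is enabled, results holds one hypothesis list per LM scale (keys such as lm_scale_1.2, or the literal no_rescore), so a caller can score every key and keep the best. A minimal sketch, reusing the levenshtein helper from the sketch after the first example:

def wer(pairs):
    errs = sum(levenshtein(ref, hyp) for ref, hyp in pairs)
    words = sum(len(ref) for ref, _ in pairs)
    return errs / words

best_key = min(results, key=lambda k: wer(results[k]))
print(f'best setting: {best_key}, WER: {wer(results[best_key]):.2%}')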
Example #3
def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
           device: Union[str, torch.device], LG: Fsa, symbols: SymbolTable):
    results = []  # a list of pairs (ref_words, hyp_words)
    for batch_idx, batch in enumerate(dataloader):
        feature = batch['features']
        supervisions = batch['supervisions']
        supervision_segments = torch.stack(
            (supervisions['sequence_idx'],
             torch.floor_divide(supervisions['start_frame'],
                                model.subsampling_factor),
             torch.floor_divide(supervisions['num_frames'],
                                model.subsampling_factor)), 1).to(torch.int32)
        texts = supervisions['text']
        assert feature.ndim == 3

        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
        with torch.no_grad():
            nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2,
                                          1)  # now nnet_output is [N, T, C]

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
        assert LG.is_cuda()
        assert LG.device == nnet_output.device, \
            f"Check failed: LG.device ({LG.device}) == nnet_output.device ({nnet_output.device})"
        # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
        # thus `tot_scores` will be `inf`. We definitely need to handle this later.
        lattices = k2.intersect_dense_pruned(LG, dense_fsa_vec, 2000.0, 20.0,
                                             30, 300)
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        best_paths = best_paths.to('cpu')
        assert best_paths.shape[0] == len(texts)

        for i in range(len(texts)):
            hyp_words = [
                symbols.get(x) for x in best_paths[i].aux_labels if x > 0
            ]
            results.append((texts[i].split(' '), hyp_words))

        if batch_idx % 10 == 0:
            logging.info('Processed batch {}/{} ({:.6f}%)'.format(
                batch_idx, len(dataloader),
                float(batch_idx) / len(dataloader) * 100))

    return results
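
This variant reads the word IDs directly off the one-best paths instead of calling get_texts(): aux_labels carries the output (word) symbols of LG, and ID 0 is epsilon, which the x > 0 filter drops. A minimal sketch for a single linear path, assuming aux_labels is a plain torch tensor as in older k2 (newer versions may return a ragged tensor):

path = best_paths[0]
hyp_words = [symbols.get(int(x)) for x in path.aux_labels if int(x) > 0]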
Example #4
def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
           HLG: Fsa, symbols: SymbolTable, num_paths: int, G: k2.Fsa,
           use_whole_lattice: bool, output_beam_size: float):
    tot_num_cuts = len(dataloader.dataset.cuts)
    num_cuts = 0
    results = defaultdict(list)
    # results is a dict whose keys and values are:
    #  - key: It indicates the lm_scale, e.g., lm_scale_1.2.
    #         If no rescoring is used, the key is the literal string: no_rescore
    #
    #  - value: It is a list of tuples (ref_words, hyp_words)

    for batch_idx, batch in enumerate(dataloader):
        texts = batch['supervisions']['text']

        hyps_dict = decode_one_batch(batch=batch,
                                     model=model,
                                     HLG=HLG,
                                     output_beam_size=output_beam_size,
                                     num_paths=num_paths,
                                     use_whole_lattice=use_whole_lattice,
                                     G=G)

        for lm_scale, hyps in hyps_dict.items():
            this_batch = []
            assert len(hyps) == len(texts)

            for i in range(len(texts)):
                hyp_words = [symbols.get(x) for x in hyps[i]]
                ref_words = texts[i].split(' ')
                this_batch.append((ref_words, hyp_words))

            results[lm_scale].extend(this_batch)

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, cuts processed until now is {}/{} ({:.6f}%)'.format(
                    batch_idx, num_cuts, tot_num_cuts,
                    float(num_cuts) / tot_num_cuts * 100))

        num_cuts += len(texts)

    return results
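
A hypothetical call site for this rescoring variant; the checkpoint path and the parameter values below are illustrative only:

d = torch.load('exp/data/G_4_gram.pt')  # assumption: a k2 Fsa saved as a dict
G = k2.Fsa.from_dict(d)
results = decode(dataloader=test_dl, model=model, HLG=HLG,
                 symbols=symbol_table, num_paths=100, G=G,
                 use_whole_lattice=True, output_beam_size=8.0)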
Example #5
def decode(
    dataloader: torch.utils.data.DataLoader,
    model: AcousticModel,
    device: Union[str, torch.device],
    HLG: Fsa,
    symbols: SymbolTable,
):
    num_batches = None
    try:
        num_batches = len(dataloader)
    except TypeError:
        pass
    num_cuts = 0
    results = []  # a list of pairs (ref_words, hyp_words)
    for batch_idx, batch in enumerate(dataloader):
        feature = batch["inputs"]
        supervisions = batch["supervisions"]
        supervision_segments = torch.stack(
            (
                supervisions["sequence_idx"],
                torch.floor_divide(supervisions["start_frame"],
                                   model.subsampling_factor),
                torch.floor_divide(supervisions["num_frames"],
                                   model.subsampling_factor),
            ),
            1,
        ).to(torch.int32)
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]
        texts = supervisions["text"]
        assert feature.ndim == 3

        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
        with torch.no_grad():
            nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2,
                                          1)  # now nnet_output is [N, T, C]

        blank_bias = -3.0
        nnet_output[:, :, 0] += blank_bias

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
        # assert HLG.is_cuda()
        assert (
            HLG.device == nnet_output.device
        ), f"Check failed: HLG.device ({HLG.device}) == nnet_output.device ({nnet_output.device})"
        # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
        # thus `tot_scores` will be `inf`. We definitely need to handle this later.
        lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, 7.0, 30,
                                             10000)

        # lattices = k2.intersect_dense(HLG, dense_fsa_vec, 10.0)
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        assert best_paths.shape[0] == len(texts)
        hyps = get_texts(best_paths, indices)
        assert len(hyps) == len(texts)

        for i in range(len(texts)):
            hyp_words = [symbols.get(x) for x in hyps[i]]
            ref_words = texts[i].split(" ")
            results.append((ref_words, hyp_words))

        if batch_idx % 10 == 0:
            batch_str = f"{batch_idx}" if num_batches is None else f"{batch_idx}/{num_batches}"
            logging.info(
                f"batch {batch_str}, number of cuts processed until now is {num_cuts}"
            )

        num_cuts += len(texts)

    return results
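
Unlike the earlier variants, this one applies the blank bias that they leave commented out: adding -3.0 to the network's output for symbol 0 (the blank) before intersection penalizes blank and tends to keep more non-blank frames alive in the lattice. The value is a decoding heuristic rather than a learned quantity, so it is worth tuning per model.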
Example #6
def _ids_to_symbols(ids: List[int], symbol_table: k2.SymbolTable) -> List[str]:
    '''Convert a list of IDs to a list of symbols.
    '''
    return [symbol_table.get(i) for i in ids]
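
A short usage sketch; the table contents are illustrative. k2.SymbolTable.from_str builds a table from "symbol id" lines:

table = k2.SymbolTable.from_str('<eps> 0\nHELLO 1\nWORLD 2')
print(_ids_to_symbols([1, 2], table))  # ['HELLO', 'WORLD']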