Example #1
File: lm_decoder.py  Project: herobd/atr
    def __init__(self, idx_to_char, params={}):

        self.idx_to_char = idx_to_char

        # Map network output indices to Kaldi phone ids and load the
        # word symbol table.
        self.reorder_1, self.reorder_2 = create_phone_map(
            params['phones_path'], idx_to_char)
        self.word_syms = SymbolTable.read_text(params['words_path'])

        self.acoustic_scale = params.get('acoustic', 1.2)
        if self.acoustic_scale < 0:
            print("Warning: acoustic scale is less than 0")
        allow_partial = params.get('allow_partial', True)
        beam = params.get('beam', 13)
        self.alphaweight = params.get('alphaweight', 0.3)

        # Read the transition model.
        trans_model = TransitionModel()
        with xopen(params['mdl_path']) as ki:
            trans_model.read(ki.stream(), ki.binary)

        # Configure the decoder and load the decoding graph (HCLG).
        decoder_opts = FasterDecoderOptions()
        decoder_opts.beam = beam

        decode_fst = read_fst_kaldi(params['fst_path'])

        self.decoder_opts = decoder_opts
        self.trans_model = trans_model
        self.decode_fst = decode_fst

        self.stats = LMStats()
        self.stats_state = None
        self.add_stats_phase = True
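
A minimal construction sketch for the decoder above. "LMDecoder" is a stand-in for the class name (not shown in the excerpt), and the paths and the idx_to_char mapping are illustrative assumptions based on the keys the constructor reads:

idx_to_char = {0: "a", 1: "b", 2: " "}   # network output index -> character
params = {
    "phones_path": "lm/phones.txt",   # used by create_phone_map
    "words_path": "lm/words.txt",     # word symbol table
    "mdl_path": "lm/final.mdl",       # transition model
    "fst_path": "lm/HCLG.fst",        # decoding graph (HCLG)
    "acoustic": 1.2,                  # acoustic scale
    "allow_partial": True,
    "beam": 13,
    "alphaweight": 0.3,
}
decoder = LMDecoder(idx_to_char, params)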
Example #2
    def LoadModels(self):
        try:
            # Define online feature pipeline
            po = ParseOptions("")

            decoder_opts = LatticeFasterDecoderOptions()
            self.endpoint_opts = OnlineEndpointConfig()
            self.decodable_opts = NnetSimpleLoopedComputationOptions()
            feat_opts = OnlineNnetFeaturePipelineConfig()

            decoder_opts.register(po)
            self.endpoint_opts.register(po)
            self.decodable_opts.register(po)
            feat_opts.register(po)

            po.read_config_file(self.CONFIG_FILES_PATH + "/online.conf")
            self.feat_info = OnlineNnetFeaturePipelineInfo.from_config(
                feat_opts)

            # Set metadata parameters
            self.samp_freq = self.feat_info.mfcc_opts.frame_opts.samp_freq
            self.frame_shift = self.feat_info.mfcc_opts.frame_opts.frame_shift_ms / 1000
            self.acwt = self.decodable_opts.acoustic_scale

            # Load Acoustic and graph models and other files
            self.transition_model, self.acoustic_model = NnetRecognizer.read_model(
                self.AM_PATH + "/final.mdl")
            graph = _fst.read_fst_kaldi(self.LM_PATH + "/HCLG.fst")
            self.decoder_graph = LatticeFasterOnlineDecoder(
                graph, decoder_opts)
            self.symbols = _fst.SymbolTable.read_text(self.LM_PATH +
                                                      "/words.txt")
            self.info = WordBoundaryInfo.from_file(
                WordBoundaryInfoNewOpts(), self.LM_PATH + "/word_boundary.int")

            self.asr = NnetLatticeFasterOnlineRecognizer(
                self.transition_model,
                self.acoustic_model,
                self.decoder_graph,
                self.symbols,
                decodable_opts=self.decodable_opts,
                endpoint_opts=self.endpoint_opts)
            del graph, decoder_opts
        except Exception as e:
            self.log.error(e)
            raise ValueError(
                "AM and LM loading failed!!! (see logs for more details)")
Example #3
def gmm_decode_faster(model_rxfilename, fst_rxfilename,
                      feature_rspecifier, words_wspecifier,
                      alignment_wspecifier="", lattice_wspecifier="",
                      word_symbol_table="", acoustic_scale=0.1,
                      allow_partial=True, decoder_opts=FasterDecoderOptions()):
    # Read model.
    trans_model = TransitionModel()
    am_gmm = AmDiagGmm()
    with xopen(model_rxfilename) as ki:
        trans_model.read(ki.stream(), ki.binary)
        am_gmm.read(ki.stream(), ki.binary)

    # Open table readers/writers.
    feature_reader = SequentialMatrixReader(feature_rspecifier)
    words_writer = IntVectorWriter(words_wspecifier)
    alignment_writer = IntVectorWriter(alignment_wspecifier)
    clat_writer = CompactLatticeWriter(lattice_wspecifier)

    # Read symbol table.
    word_syms = None
    if word_symbol_table != "":
        word_syms = SymbolTable.read_text(word_symbol_table)
        if not word_syms:
            raise RuntimeError("Could not read symbol table from file {}"
                               .format(word_symbol_table))

    # NOTE:
    # It is important to read decode_fst after opening feature reader as
    # it can prevent crashes on systems without enough virtual memory.

    # Read decoding graph and instantiate decoder.
    decode_fst = read_fst_kaldi(fst_rxfilename)
    decoder = FasterDecoder(decode_fst, decoder_opts)

    tot_like = 0.0
    frame_count = 0
    num_success, num_fail = 0, 0
    start = time.time()

    for key, features in feature_reader:
        if features.num_rows == 0:
            num_fail += 1
            logging.warning("Zero-length utterance: {}".format(key))
            continue

        gmm_decodable = DecodableAmDiagGmmScaled(am_gmm, trans_model,
                                                 features, acoustic_scale)
        decoder.decode(gmm_decodable)

        if not (allow_partial or decoder.reached_final()):
            num_fail += 1
            logging.warning("Did not successfully decode utterance {}, len = {}"
                            .format(key, features.num_rows))
            continue

        try:
            best_path = decoder.get_best_path()
        except RuntimeError:
            num_fail += 1
            logging.warning("Did not successfully decode utterance {}, len = {}"
                            .format(key, features.num_rows))
            continue

        if not decoder.reached_final():
            logging.warning("Decoder did not reach end-state, outputting "
                            "partial traceback since --allow-partial=true")

        ali, words, weight = get_linear_symbol_sequence(best_path)

        words_writer[key] = words

        if alignment_writer.is_open():
            alignment_writer[key] = ali

        if clat_writer.is_open():
            if acoustic_scale != 0.0:
                scale = acoustic_lattice_scale(1.0 / acoustic_scale)
                scale_lattice(scale, best_path)
            best_path = convert_lattice_to_compact_lattice(best_path)
            clat_writer[key] = best_path

        if word_syms:
            syms = convert_indices_to_symbols(word_syms, words)
            print(key, " ".join(syms), file=sys.stderr)

        num_success += 1
        frame_count += features.num_rows
        like = -(weight.value1 + weight.value2)
        tot_like += like
        logging.info("Log-like per frame for utterance {} is {} over {} "
                     "frames.".format(key, like / features.num_rows,
                                      features.num_rows))
        logging.debug("Cost for utterance {} is {} + {}"
                      .format(key, weight.value1, weight.value2))

    elapsed = time.time() - start
    logging.info("Time taken [excluding initialization] {}s: real-time factor "
                 "assuming 100 frames/sec is {}"
                 .format(elapsed, elapsed * 100 / frame_count))
    logging.info("Done {} utterances, failed for {}"
                 .format(num_success, num_fail))
    logging.info("Overall log-likelihood per frame is {} over {} frames."
                 .format(tot_like / frame_count, frame_count))

    feature_reader.close()
    words_writer.close()
    if alignment_writer.is_open():
        alignment_writer.close()
    if clat_writer.is_open():
        clat_writer.close()

    return num_success != 0
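
A call-site sketch for the function above, using the standard Kaldi "ark:"/"scp:" table specifiers; all paths and the acoustic-scale value are illustrative:

ok = gmm_decode_faster(
    model_rxfilename="exp/mono/final.mdl",
    fst_rxfilename="exp/mono/graph/HCLG.fst",
    feature_rspecifier="ark:mfcc/feats.ark",
    words_wspecifier="ark,t:decode/words.ark",
    alignment_wspecifier="ark:decode/ali.ark",
    lattice_wspecifier="ark:decode/lat.ark",
    word_symbol_table="exp/mono/graph/words.txt",
    acoustic_scale=0.083333,
    allow_partial=True)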
Example #4
# Read the transition and acoustic models; the original excerpt starts
# mid-way through this block, so the model path below is an assumption.
trans_model = TransitionModel()
acoustic_model = AmDiagGmm()
with xopen("models/mono/final.mdl") as ki:
    trans_model.read(ki.stream(), ki.binary)
    acoustic_model.read(ki.stream(), ki.binary)


# Define the decodable wrapper: (features, acoustic_scale) -> decodable
def make_decodable_wrapper(trans_model, acoustic_model):
    def decodable_wrapper(features, acoustic_scale):
        return DecodableAmDiagGmmScaled(acoustic_model, trans_model, features,
                                        acoustic_scale)

    return decodable_wrapper


decodable_wrapper = make_decodable_wrapper(trans_model, acoustic_model)

# Define the decoder
decoding_graph = read_fst_kaldi("models/mono/graph/HCLG.fst")
decoder_opts = LatticeFasterDecoderOptions()
decoder_opts.beam = 13.0
decoder_opts.lattice_beam = 6.0
decoder = LatticeFasterDecoder(decoding_graph, decoder_opts)

# Define the recognizer
symbols = SymbolTable.read_text("models/mono/graph/words.txt")
asr = Recognizer(decoder, decodable_wrapper, symbols)

# Decode wave files
# for key, wav in SequentialWaveReader("scp:wav.scp"):
#     feats = feat_pipeline(wav)
#     out = asr.decode(feats)
#     print(key, out["text"], flush=True)
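
feat_pipeline is not defined in this excerpt; in the PyKaldi decoder examples it is typically an MFCC-plus-deltas function along these lines (a sketch, with default feature options assumed):

from kaldi.feat.mfcc import Mfcc, MfccOptions
from kaldi.feat.functions import compute_deltas, DeltaFeaturesOptions

def make_feat_pipeline(base, opts=DeltaFeaturesOptions()):
    def feat_pipeline(wav):
        # Compute MFCCs on the first channel and append delta features.
        feats = base.compute_features(wav.data()[0], wav.samp_freq, 1.0)
        return compute_deltas(opts, feats)
    return feat_pipeline

feat_pipeline = make_feat_pipeline(Mfcc(MfccOptions()))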
Example #5
    def __init__(
        self,
        cfg: KaldiDecoderConfig,
        beam: int,
        nbest: int = 1,
    ):
        try:
            from kaldi.asr import FasterRecognizer, LatticeFasterRecognizer
            from kaldi.base import set_verbose_level
            from kaldi.decoder import (
                FasterDecoder,
                FasterDecoderOptions,
                LatticeFasterDecoder,
                LatticeFasterDecoderOptions,
            )
            from kaldi.lat.functions import DeterminizeLatticePhonePrunedOptions
            from kaldi.fstext import read_fst_kaldi, SymbolTable
        except ImportError:
            warnings.warn(
                "pykaldi is required for this functionality. Please install from https://github.com/pykaldi/pykaldi"
            )

        # set_verbose_level(2)

        self.acoustic_scale = cfg.acoustic_scale
        self.nbest = nbest

        if cfg.hlg_graph_path is None:
            assert (
                cfg.kaldi_initializer_config is not None
            ), "Must provide hlg graph path or kaldi initializer config"
            cfg.hlg_graph_path = initalize_kaldi(cfg.kaldi_initializer_config)

        assert os.path.exists(cfg.hlg_graph_path), cfg.hlg_graph_path

        if cfg.is_lattice:
            self.dec_cls = LatticeFasterDecoder
            opt_cls = LatticeFasterDecoderOptions
            self.rec_cls = LatticeFasterRecognizer
        else:
            assert self.nbest == 1, "nbest > 1 requires lattice decoder"
            self.dec_cls = FasterDecoder
            opt_cls = FasterDecoderOptions
            self.rec_cls = FasterRecognizer

        self.decoder_options = opt_cls()
        self.decoder_options.beam = beam
        self.decoder_options.max_active = cfg.max_active
        self.decoder_options.beam_delta = cfg.beam_delta
        self.decoder_options.hash_ratio = cfg.hash_ratio

        if cfg.is_lattice:
            self.decoder_options.lattice_beam = cfg.lattice_beam
            self.decoder_options.prune_interval = cfg.prune_interval
            self.decoder_options.determinize_lattice = cfg.determinize_lattice
            self.decoder_options.prune_scale = cfg.prune_scale
            det_opts = DeterminizeLatticePhonePrunedOptions()
            det_opts.max_mem = cfg.max_mem
            det_opts.phone_determinize = cfg.phone_determinize
            det_opts.word_determinize = cfg.word_determinize
            det_opts.minimize = cfg.minimize
            self.decoder_options.det_opts = det_opts

        self.output_symbols = {}
        with open(cfg.output_dict, "r") as f:
            for line in f:
                items = line.rstrip().split()
                assert len(items) == 2
                self.output_symbols[int(items[1])] = items[0]

        logger.info(f"Loading FST from {cfg.hlg_graph_path}")
        self.fst = read_fst_kaldi(cfg.hlg_graph_path)
        self.symbol_table = SymbolTable.read_text(cfg.output_dict)

        self.executor = ThreadPoolExecutor(max_workers=cfg.num_threads)
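
The parsing loop above expects cfg.output_dict to be a Kaldi words.txt-style symbol table with one "<symbol> <integer-id>" pair per line (the same file is also read by SymbolTable.read_text). An illustrative file, with made-up entries, could be created like this:

with open("words.txt", "w") as f:
    f.write("<eps> 0\n")
    f.write("hello 1\n")
    f.write("world 2\n")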
Example #6
# Read the transition and acoustic models; the original excerpt starts
# mid-way through this block, so the model path below is assumed to sit
# alongside the other wsj model files used in this example.
trans_model = TransitionModel()
acoustic_model = AmDiagGmm()
with xopen("/home/dogan/tools/pykaldi/egs/models/wsj/final.mdl") as ki:
    trans_model.read(ki.stream(), ki.binary)
    acoustic_model.read(ki.stream(), ki.binary)


# Define the decodable wrapper: (features, acoustic_scale) -> decodable
def make_decodable_wrapper(trans_model, acoustic_model):
    def decodable_wrapper(features, acoustic_scale):
        return DecodableAmDiagGmmScaled(acoustic_model, trans_model, features,
                                        acoustic_scale)

    return decodable_wrapper


decodable_wrapper = make_decodable_wrapper(trans_model, acoustic_model)

# Define the decoder
decoding_graph = read_fst_kaldi(
    "/home/dogan/tools/pykaldi/egs/models/wsj/HCLG.fst")
decoder_opts = FasterDecoderOptions()
decoder_opts.beam = 13
decoder_opts.max_active = 7000
decoder = FasterDecoder(decoding_graph, decoder_opts)

# Define the recognizer
symbols = SymbolTable.read_text(
    "/home/dogan/tools/pykaldi/egs/models/wsj/words.txt")
asr = Recognizer(decoder, decodable_wrapper, symbols)

# Decode wave files
for key, wav in SequentialWaveReader(
        "scp:/home/dogan/tools/pykaldi/egs/decoder/test2.scp"):
    feats = feat_pipeline(wav)
    out = asr.decode(feats)
    print(key, out["text"], flush=True)