Example #1
    def decode_one(self, logits, padding):
        from kaldi.matrix import Matrix

        # Build a fresh decoder and recognizer for this utterance
        decoder = self.dec_cls(self.fst, self.decoder_options)
        asr = self.rec_cls(decoder,
                           self.symbol_table,
                           acoustic_scale=self.acoustic_scale)

        # Drop padded frames before decoding
        if padding is not None:
            logits = logits[~padding]

        # Wrap the emissions in a Kaldi matrix and decode it
        mat = Matrix(logits.numpy())

        out = asr.decode(mat)

        if self.nbest > 1:
            from kaldi.fstext import shortestpath
            from kaldi.fstext.utils import (
                convert_compact_lattice_to_lattice,
                convert_lattice_to_std,
                convert_nbest_to_list,
                get_linear_symbol_sequence,
            )

            lat = out["lattice"]

            # Extract the n shortest paths from the lattice
            sp = shortestpath(lat, nshortest=self.nbest)

            # Convert the n-best compact lattice into a list of linear FSTs
            sp = convert_compact_lattice_to_lattice(sp)
            sp = convert_lattice_to_std(sp)
            seq = convert_nbest_to_list(sp)

            results = []
            for s in seq:
                # Read off the output symbols and path weight of each hypothesis
                _, o, w = get_linear_symbol_sequence(s)
                words = [self.output_symbols[z] for z in o]
                results.append({
                    "tokens": words,
                    "words": words,
                    "score": w.value,
                    "emissions": logits,
                })
            return results
        else:
            words = out["text"].split()
            return [{
                "tokens": words,
                "words": words,
                "score": out["likelihood"],
                "emissions": logits,
            }]
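
A minimal driving sketch for decode_one, assuming `decoder` is an already-configured instance of the wrapper class above (its self.fst, self.symbol_table, self.nbest, etc. are set up in its constructor) and that emissions arrive as a 2-D torch tensor; the shapes and values below are illustrative only.

import torch

# Illustrative emissions: 200 frames over a 42-unit output vocabulary
logits = torch.randn(200, 42)

# Boolean mask with True marking padded frames (here the last 20)
padding = torch.zeros(200, dtype=torch.bool)
padding[180:] = True

# `decoder` is an assumed, pre-built instance of the class defining decode_one
hypos = decoder.decode_one(logits, padding)
for hypo in hypos:
    print(" ".join(hypo["words"]), hypo["score"])
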
Example #2
import os

import ffmpeg
import yaml
from kaldi.asr import NnetLatticeFasterRecognizer
from kaldi.decoder import LatticeFasterDecoderOptions
from kaldi.fstext import SymbolTable, indices_to_symbols, shortestpath
from kaldi.fstext.utils import get_linear_symbol_sequence
from kaldi.lat import functions
from kaldi.nnet3 import NnetSimpleComputationOptions
from kaldi.util.table import SequentialMatrixReader


def asr(filenameS_hash, filenameS, asr_beamsize=13, asr_max_active=8000):
    models_dir = "models/"

    # Read the model configuration from its YAML file
    config_file = "models/kaldi_tuda_de_nnet3_chain2.yaml"
    with open(config_file, 'r') as stream:
        model_yaml = yaml.safe_load(stream)
    decoder_yaml_opts = model_yaml['decoder']

    scp_filename = "tmp/%s.scp" % filenameS_hash
    wav_filename = "tmp/%s.wav" % filenameS_hash
    spk2utt_filename = "tmp/%s_spk2utt" % filenameS_hash

    # write scp file
    with open(scp_filename, 'w') as scp_file:
        scp_file.write("%s tmp/%s.wav\n" % (filenameS_hash, filenameS_hash))

    # write spk2utt file
    with open(spk2utt_filename, 'w') as scp_file:
        scp_file.write("%s %s\n" % (filenameS_hash, filenameS_hash))

    # use ffmpeg to convert the input media file (any format!) to 16 kHz wav mono
    (
        ffmpeg
            .input(filenameS)
            .output(wav_filename, acodec='pcm_s16le', ac=1, ar='16k')
            .overwrite_output()
            .run()
    )

    # Construct recognizer
    decoder_opts = LatticeFasterDecoderOptions()
    decoder_opts.beam = asr_beamsize
    decoder_opts.max_active = asr_max_active
    decodable_opts = NnetSimpleComputationOptions()
    decodable_opts.acoustic_scale = 1.0
    decodable_opts.frame_subsampling_factor = 3
    decodable_opts.frames_per_chunk = 150
    asr = NnetLatticeFasterRecognizer.from_files(
        models_dir + decoder_yaml_opts["model"],
        models_dir + decoder_yaml_opts["fst"],
        models_dir + decoder_yaml_opts["word-syms"],
        decoder_opts=decoder_opts, decodable_opts=decodable_opts)

    # Construct symbol table
    symbols = SymbolTable.read_text(models_dir + decoder_yaml_opts["word-syms"])
    phi_label = symbols.find_index("#0")  # backoff/disambiguation symbol

    # Define feature pipelines as Kaldi rspecifiers
    mfcc_config = models_dir + decoder_yaml_opts["mfcc-config"]
    ivector_config = models_dir + decoder_yaml_opts["ivector-extraction-config"]
    feats_rspec = "ark:compute-mfcc-feats --config=%s scp:%s ark:- |" % (
        mfcc_config, scp_filename)
    ivectors_rspec = (
        "ark:compute-mfcc-feats --config=%s scp:%s ark:-"
        " | ivector-extract-online2 --config=%s ark:%s ark:- ark:- |" % (
            mfcc_config, scp_filename, ivector_config, spk2utt_filename))

    did_decode = False
    # Decode wav files
    with SequentialMatrixReader(feats_rspec) as f, \
            SequentialMatrixReader(ivectors_rspec) as i:
        for (fkey, feats), (ikey, ivectors) in zip(f, i):
            did_decode = True
            assert fkey == ikey
            out = asr.decode((feats, ivectors))
            # Best path through the decoding lattice
            best_path = functions.compact_lattice_shortest_path(out["lattice"])
            words, _, _ = get_linear_symbol_sequence(shortestpath(best_path))
            # Word ids plus begin frames and lengths along the best path
            timing = functions.compact_lattice_to_word_alignment(best_path)

    assert did_decode

    # Map word ids back to word symbols
    words = indices_to_symbols(symbols, timing[0])

    # Build (word, begin_frame, length_in_frames) triples
    vtt = list(map(list, zip(words, timing[1], timing[2])))

    # Cleanup tmp files
    print('removing tmp file:', scp_filename)
    os.remove(scp_filename)
    print('removing tmp file:', wav_filename)
    os.remove(wav_filename)
    print('removing tmp file:', spk2utt_filename)
    os.remove(spk2utt_filename)
    return vtt, words
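
A hedged usage sketch under the function's own assumptions: the models/ and tmp/ directories exist and the YAML file points at a downloaded model. The media file name and the md5-based naming scheme are illustrative. Since decodable_opts sets frame_subsampling_factor=3 on top of Kaldi's default 10 ms frame shift, one alignment frame corresponds to 30 ms.

import hashlib

media_file = "talk.mp4"  # assumed input; anything ffmpeg can read
media_hash = hashlib.md5(media_file.encode("utf-8")).hexdigest()

vtt, words = asr(media_hash, media_file)

# 10 ms MFCC frame shift * frame_subsampling_factor=3 = 30 ms per frame
FRAME_SECONDS = 0.03
for word, begin, length in vtt:
    start = begin * FRAME_SECONDS
    end = (begin + length) * FRAME_SECONDS
    print("%6.2f %6.2f %s" % (start, end, word))
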
Example #3
from kaldi.asr import LatticeRnnlmPrunedRescorer
from kaldi.fstext import SymbolTable, indices_to_symbols, shortestpath
from kaldi.fstext.utils import get_linear_symbol_sequence
from kaldi.lat.functions import ComposeLatticePrunedOptions
from kaldi.rnnlm import RnnlmComputeStateComputationOptions
from kaldi.util.table import SequentialMatrixReader

# `asr` is an NnetLatticeFasterRecognizer constructed as in Example #2

# Construct RNNLM rescorer
symbols = SymbolTable.read_text("lm/words.txt")
rnnlm_opts = RnnlmComputeStateComputationOptions()
rnnlm_opts.bos_index = symbols.find_index("<s>")
rnnlm_opts.eos_index = symbols.find_index("</s>")
rnnlm_opts.brk_index = symbols.find_index("<brk>")
compose_opts = ComposeLatticePrunedOptions()
compose_opts.lattice_compose_beam = 4

# Construct the pruned RNNLM rescorer; the second argument is a Kaldi pipe
# that produces the word embedding matrix on the fly
rescorer = LatticeRnnlmPrunedRescorer.from_files(
    "lm/G.carpa",
    "rnnlm-get-word-embedding lm/word_feats.txt lm/feat_embedding.final.mat -|",
    "lm/final.raw",
    acoustic_scale=1.0,
    max_ngram_order=4,
    use_const_arpa=True,
    opts=rnnlm_opts,
    compose_opts=compose_opts)

# Define feature pipelines as Kaldi rspecifiers
feats_rspec = "ark:compute-mfcc-feats --config=mfcc.conf scp:wav.scp ark:- |"
ivectors_rspec = (
    "ark:compute-mfcc-feats --config=mfcc.conf scp:wav.scp ark:-"
    " | ivector-extract-online2 --config=ivector.conf ark:spk2utt ark:- ark:- |"
)

# Decode wav files and rescore the resulting lattices
with SequentialMatrixReader(feats_rspec) as f, \
     SequentialMatrixReader(ivectors_rspec) as i:
    for (fkey, feats), (ikey, ivectors) in zip(f, i):
        assert fkey == ikey
        out = asr.decode((feats, ivectors))
        # Rescore the decoding lattice with the pruned RNNLM
        rescored_lat = rescorer.rescore(out["lattice"])
        words, _, _ = get_linear_symbol_sequence(shortestpath(rescored_lat))
        print(fkey, " ".join(indices_to_symbols(symbols, words)), flush=True)
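
Two settings dominate the speed/accuracy trade-off in this rescoring pass: lattice_compose_beam bounds how much of each lattice the pruned composition explores, and max_ngram_order caps how many distinct word histories the RNNLM state cache keeps apart (histories that agree on their last max_ngram_order - 1 words are merged), so smaller values make rescoring faster but coarser.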