Beispiel #1
0
def test_kaldi_audio(wav_file, audio, dtype):
    """Compare shennong ``Audio`` with Kaldi's wave reader on one wav file.

    The file is read through Kaldi's ``SequentialWaveReader``. Its samples
    are wrapped into an ``Audio`` object with validation disabled, because
    Kaldi keeps the raw integer amplitude range. Both audio objects then go
    through ``MfccProcessor``. The resulting features must agree in shape,
    times, properties, dtype and (approximately) values.
    """
    # make sure we have results when loading a wav file with
    # shennong.Audio and with the Kaldi code.
    audio_kaldi = None
    with tempfile.NamedTemporaryFile('w+') as tfile:
        tfile.write('test {}\n'.format(wav_file))
        tfile.seek(0)
        with SequentialWaveReader('scp,t:' + tfile.name) as reader:
            for _, wave in reader:
                audio_kaldi = Audio(
                    wave.data().numpy().reshape(audio.data.shape),
                    audio.sample_rate,
                    validate=False)
    # Fail with a clear assertion instead of a NameError if Kaldi read nothing.
    assert audio_kaldi is not None

    audio = audio.astype(dtype)
    assert audio.duration == audio_kaldi.duration
    assert audio.dtype == dtype
    assert audio.is_valid()
    assert audio_kaldi.dtype == np.float32
    assert not audio_kaldi.is_valid()  # not in [-1, 1] but [-2**15, 2**15-1]

    mfcc = MfccProcessor().process(audio)
    mfcc_kaldi = MfccProcessor().process(audio_kaldi)
    assert mfcc.shape == mfcc_kaldi.shape
    assert np.array_equal(mfcc.times, mfcc_kaldi.times)
    assert mfcc.properties == mfcc_kaldi.properties
    assert mfcc.dtype == mfcc_kaldi.dtype
    # ``assert pytest.approx(a, b)`` is always true: an approx object is
    # truthy and ``b`` would be taken as the relative tolerance. Compare the
    # feature values explicitly instead.
    assert mfcc.data == pytest.approx(mfcc_kaldi.data)
Beispiel #2
0
def lid_module(key, audio_file, start, end):
    """Predict a language label for one segment of a wav file.

    ``sox`` (run through a Kaldi pipe rspecifier) cuts ``audio_file`` with
    ``trim start (end - start)``, downmixes to one channel and resamples to
    8000 Hz. High-resolution MFCCs are computed on the result, the
    module-level ``CMVN`` array is subtracted, and ``network_eval`` scores
    the features.

    Args:
        key: utterance id written into the generated scp line.
        audio_file: path of the input wav file.
        start: segment start, as passed to ``sox trim``.
        end: segment end; ``end - start`` is passed as the trim length.

    Returns:
        ``i2l[v.argmax()]`` for the network output ``v``.
    """
    # NOTE(review): key and audio_file are spliced unquoted into a shell
    # pipeline; paths containing spaces or quotes will break it.
    wav_spc = ("scp:echo " + key + " 'sox -V0 -t wav " + audio_file
               + " -c 1 -r 8000 -t wav - trim " + str(start) + " "
               + str(float(end) - float(start)) + "|' |")
    hires_mfcc = Mfcc(hires_mfcc_opts)
    # Compute the features while the reader is open, then close it (and the
    # sox pipe behind it) deterministically instead of leaking it.
    with SequentialWaveReader(wav_spc) as reader:
        wav = reader.value()
        hi_feat = hires_mfcc.compute_features(wav.data()[0], wav.samp_freq,
                                              1.0)
    hi_feat = hi_feat.numpy() - CMVN
    # Transpose the feature matrix, then add a leading batch axis and a
    # trailing channel axis for the network.
    X = hi_feat.T
    X = np.expand_dims(np.expand_dims(X, 0), -1)
    v = network_eval.predict(X)
    return i2l[v.argmax()]
Beispiel #3
0
def lid_module(key, audio_file, start, end, max_frames=384):
    """Compute a language-ID embedding for one segment of a wav file.

    ``sox`` (run through a Kaldi pipe rspecifier) cuts ``audio_file`` with
    ``trim start (end - start)``, downmixes to one channel and resamples to
    16000 Hz. High-resolution MFCCs are computed on the result and the
    module-level ``CMVN`` array is subtracted. The transposed feature matrix
    is cropped or zero-padded to ``max_frames`` columns and fed to
    ``nn_LID_model_DA.emb``.

    Args:
        key: utterance id written into the generated scp line.
        audio_file: path of the input wav file.
        start: segment start, as passed to ``sox trim``.
        end: segment end; ``end - start`` is passed as the trim length.
        max_frames: number of columns the feature matrix is cropped/padded
            to (previously hard-coded as 384).

    Returns:
        The first element of ``nn_LID_model_DA.emb(...)``. The original
        computed it but returned None.
    """
    # NOTE(review): key and audio_file are spliced unquoted into a shell
    # pipeline; paths containing spaces or quotes will break it.
    wav_spc = ("scp:echo " + key + " 'sox -V0 -t wav " + audio_file
               + " -c 1 -r 16000 -t wav - trim " + str(start) + " "
               + str(float(end) - float(start)) + "|' |")
    hires_mfcc = Mfcc(hires_mfcc_opts)
    # Compute the features while the reader is open, then close it.
    with SequentialWaveReader(wav_spc) as reader:
        wav = reader.value()
        hi_feat = hires_mfcc.compute_features(wav.data()[0], wav.samp_freq,
                                              1.0)
    hi_feat = hi_feat.numpy() - CMVN
    X = hi_feat.T
    print(X.shape)
    # Crop or zero-pad the second axis to exactly max_frames columns. The
    # row count is taken from X instead of a hard-coded 40. Both paths now
    # yield float32, which is what the zero-padding path already produced.
    num_rows, num_cols = X.shape
    fixed = np.zeros((num_rows, max_frames), dtype=np.float32)
    kept = min(num_cols, max_frames)
    fixed[:, :kept] = X[:, :kept]
    X = np.expand_dims(fixed, 0)
    print(X.shape)
    emb = nn_LID_model_DA.emb(torch.from_numpy(X))[0]
    print(emb.shape)
    return emb
Beispiel #4
0
def test_compare_kaldi(wav_file):
    """Check that ``Audio.load`` and Kaldi's wave reader see the same samples.

    ``Audio.load`` returns a 1-D int16 array. Kaldi returns a float32
    (channels, samples) matrix holding the same integer values.
    """
    a1 = Audio.load(wav_file).data

    a2 = None
    with tempfile.NamedTemporaryFile('w+') as tfile:
        tfile.write('test {}\n'.format(wav_file))
        tfile.seek(0)
        with SequentialWaveReader('scp,t:' + tfile.name) as reader:
            for _, wave in reader:
                a2 = wave.data().numpy()
    # Fail with a clear assertion instead of a NameError if Kaldi read nothing.
    assert a2 is not None

    assert a1.max() == a2.max()
    assert a1.min() == a2.min()
    assert len(a1) == len(a2.flatten()) == 22713
    assert a1.dtype == np.int16 and a2.dtype == np.float32
    assert a1.shape == (22713,) and a2.shape == (1, 22713)
    # ``assert pytest.approx(a1, a2)`` is always true (a2 would be the
    # relative tolerance); compare the sample values element-wise instead.
    assert a1 == pytest.approx(a2.flatten())
def extract_spec(filename,
                 samp_freq,
                 frame_length_ms=25,
                 frame_shift_ms=10,
                 round_to_power_of_two=True,
                 snip_edges=True):
    '''
    Extract a spectrogram from a wav file using Kaldi.

    The file is listed in ``wav.scp`` (written to the current directory) and
    read through a Kaldi table reader. It is decimated to ``samp_freq`` and
    averaged across channels before the spectrogram is computed. The
    spectrogram is also written as a text archive to ``spec.ark`` in the
    current directory.

    args:
        filename: wav file path (must not contain whitespace)
        samp_freq: target sample frequency; the file's rate must be an
            integer multiple of it
        frame_length_ms: frame length in milliseconds
        frame_shift_ms: frame shift in milliseconds
        round_to_power_of_two: Kaldi ``round_to_power_of_two`` frame option
        snip_edges: Kaldi ``snip_edges`` frame option
    return:
        spectrogram: numpy array (frame, fre), or None if the utterance was
        skipped because it is shorter than ``--min-duration``
    '''
    # get rspec and wspec
    # NOTE(review): the scp line is whitespace-separated, so a filename
    # containing spaces cannot be represented here.
    with open('wav.scp', 'w') as scp_file:
        scp_file.write('test1 ' + filename + '\n')
    rspec = 'scp,p:' + 'wav.scp'
    wspec = 'ark,t:' + 'spec.ark'
    # set po
    usage = """Extract MFCC features.Usage: example.py [opts...] <rspec> <wspec>"""
    po = ParseOptions(usage)
    po.register_float("min-duration", 0.0, "minimum segment duration")
    # NOTE(review): parse_args() with no argument presumably parses the
    # process's own command line; confirm this is intended inside a
    # library function.
    opts = po.parse_args()
    # set options
    spec_opts = SpectrogramOptions()
    spec_opts.frame_opts.samp_freq = samp_freq
    spec_opts.frame_opts.frame_length_ms = frame_length_ms
    spec_opts.frame_opts.frame_shift_ms = frame_shift_ms
    spec_opts.frame_opts.round_to_power_of_two = round_to_power_of_two
    spec_opts.frame_opts.snip_edges = snip_edges
    # NOTE(review): registered after parse_args(), so these options are not
    # read from the command line.
    spec_opts.register(po)
    spec = Spectrogram(spec_opts)
    sf = spec_opts.frame_opts.samp_freq
    # Stays None when every utterance is skipped; the original raised
    # UnboundLocalError at the return statement in that case.
    f_array = None
    with SequentialWaveReader(rspec) as reader, MatrixWriter(wspec) as writer:
        for key, wav in reader:
            if wav.duration < opts.min_duration:
                continue
            assert (wav.samp_freq >= sf)
            assert (wav.samp_freq % sf == 0)
            s = wav.data()
            # Keep every k-th sample. This is plain decimation with no
            # low-pass filter, so content above sf / 2 will alias.
            s = s[:, ::int(wav.samp_freq / sf)]
            # Average the channels (rows) into a single signal.
            m = SubVector(mean(s, axis=0))
            feats = spec.compute_features(m, sf, 1.0)
            f_array = np.array(feats)
            writer[key] = feats
    return f_array
def decode_chunked_partial(scp="scp:wav.scp"):
    """Decode each utterance chunk by chunk, printing partial hypotheses.

    Uses the module-level ``asr`` recognizer, ``feat_info`` and
    ``chunk_size``. After every non-final chunk that advanced decoding, a
    partial hypothesis is printed as ``<key>-part<N>``. The final hypothesis
    is printed as ``<key>-final``.

    Args:
        scp: Kaldi rspecifier of the wave files, as in
            ``decode_chunked_partial_endpointing``. The original ignored this
            argument and always read ``"scp:wav.scp"``, which is kept as the
            default.
    """
    with SequentialWaveReader(scp) as reader:
        for key, wav in reader:
            feat_pipeline = OnlineNnetFeaturePipeline(feat_info)
            asr.set_input_pipeline(feat_pipeline)
            asr.init_decoding()
            data = wav.data()[0]
            last_chunk = False
            part = 1
            prev_num_frames_decoded = 0
            for i in range(0, len(data), chunk_size):
                if i + chunk_size >= len(data):
                    last_chunk = True
                feat_pipeline.accept_waveform(wav.samp_freq,
                                              data[i:i + chunk_size])
                if last_chunk:
                    feat_pipeline.input_finished()
                asr.advance_decoding()
                num_frames_decoded = asr.decoder.num_frames_decoded()
                # Only print a partial result when new frames were decoded.
                if not last_chunk and num_frames_decoded > prev_num_frames_decoded:
                    prev_num_frames_decoded = num_frames_decoded
                    out = asr.get_partial_output()
                    print(key + "-part%d" % part, out["text"], flush=True)
                    part += 1
            asr.finalize_decoding()
            out = asr.get_output()
            print(key + "-final", out["text"], flush=True)
def _fit_confidences(confd, token_length):
    """Pad (with 1.0) or truncate ``confd`` to exactly ``token_length`` items."""
    if len(confd) < token_length:
        print(
            "WARNING: less computed confidences than token length! Fixing this with padding!"
        )
        return np.pad(confd, [0, token_length - len(confd)],
                      mode='constant',
                      constant_values=1.0)
    if len(confd) > token_length:
        print(
            "WARNING: more computed confidences than token length! Fixing this with slicing!"
        )
        return confd[:token_length]
    return confd


def decode_chunked_partial_endpointing(asr,
                                       feat_info,
                                       decodable_opts,
                                       scp,
                                       chunk_size=1024,
                                       compute_confidences=True,
                                       asr_client=None,
                                       speaker="Speaker",
                                       pad_confidences=True):
    """Decode wave files chunk-wise with endpointing and speaker adaptation.

    Audio is fed to ``asr`` in chunks. When an endpoint is detected, the
    current utterance is finalized, its MBR confidences are computed and it
    is reported through ``asr_client.completeUtterance``. Decoding then
    restarts on the not-yet-consumed samples. Partial hypotheses are
    reported through ``asr_client.partialUtterance``. The i-vector
    adaptation state is carried across utterances, and silence weighting
    updates the i-vector frame weights as decoding proceeds.

    Args:
        asr: online recognizer.
        feat_info: online nnet feature pipeline info; also supplies
            ``ivector_extractor_info`` and ``silence_weighting_config``.
        decodable_opts: supplies ``frame_subsampling_factor``.
        scp: Kaldi rspecifier of the wave files to decode.
        chunk_size: number of samples fed to the pipeline per step.
        compute_confidences: NOTE(review): currently unused; confidences are
            always computed.
        asr_client: optional receiver of partial/complete utterance callbacks.
        speaker: NOTE(review): currently unused.
        pad_confidences: pad/truncate confidences to the whitespace token
            count of the hypothesis. This now applies to the final utterance
            of each file too, not only to endpointed ones.
    """
    # Decode (chunked + partial output + endpointing
    #         + ivector adaptation + silence weighting)
    adaptation_state = OnlineIvectorExtractorAdaptationState.from_info(
        feat_info.ivector_extractor_info)
    with SequentialWaveReader(scp) as reader:
        for key, wav in reader:
            feat_pipeline = OnlineNnetFeaturePipeline(feat_info)
            feat_pipeline.set_adaptation_state(adaptation_state)
            asr.set_input_pipeline(feat_pipeline)
            asr.init_decoding()
            sil_weighting = OnlineSilenceWeighting(
                asr.transition_model, feat_info.silence_weighting_config,
                decodable_opts.frame_subsampling_factor)
            data = wav.data()[0]
            print("type(data):", type(data))
            last_chunk = False
            utt, part = 1, 1
            prev_num_frames_decoded, offset = 0, 0
            for i in range(0, len(data), chunk_size):
                if i + chunk_size >= len(data):
                    last_chunk = True
                feat_pipeline.accept_waveform(wav.samp_freq,
                                              data[i:i + chunk_size])
                if last_chunk:
                    feat_pipeline.input_finished()
                if sil_weighting.active():
                    sil_weighting.compute_current_traceback(asr.decoder)
                    feat_pipeline.ivector_feature().update_frame_weights(
                        sil_weighting.get_delta_weights(
                            feat_pipeline.num_frames_ready()))
                asr.advance_decoding()
                num_frames_decoded = asr.decoder.num_frames_decoded()
                if last_chunk:
                    # The final utterance is finalized after the loop.
                    continue
                if asr.endpoint_detected():
                    asr.finalize_decoding()
                    out = asr.get_output()
                    mbr = MinimumBayesRisk(out["lattice"])
                    confd = mbr.get_one_best_confidences()
                    if pad_confidences:
                        confd = _fit_confidences(confd,
                                                 len(out["text"].split()))
                    print(confd)
                    if asr_client is not None:
                        asr_client.completeUtterance(
                            utterance=out["text"],
                            key=key + "-utt%d-part%d" % (utt, part),
                            confidences=confd)
                    # Convert the decoded frames back to a sample count, so
                    # the next utterance restarts right after this one.
                    offset += int(num_frames_decoded *
                                  decodable_opts.frame_subsampling_factor *
                                  feat_pipeline.frame_shift_in_seconds() *
                                  wav.samp_freq)
                    # Carry the i-vector adaptation into a fresh pipeline.
                    feat_pipeline.get_adaptation_state(adaptation_state)
                    feat_pipeline = OnlineNnetFeaturePipeline(feat_info)
                    feat_pipeline.set_adaptation_state(adaptation_state)
                    asr.set_input_pipeline(feat_pipeline)
                    asr.init_decoding()
                    sil_weighting = OnlineSilenceWeighting(
                        asr.transition_model,
                        feat_info.silence_weighting_config,
                        decodable_opts.frame_subsampling_factor)
                    # Re-feed the samples already read but not yet decoded.
                    remainder = data[offset:i + chunk_size]
                    feat_pipeline.accept_waveform(wav.samp_freq, remainder)
                    utt += 1
                    part = 1
                    prev_num_frames_decoded = 0
                elif num_frames_decoded > prev_num_frames_decoded:
                    prev_num_frames_decoded = num_frames_decoded
                    out = asr.get_partial_output()
                    if asr_client is not None:
                        asr_client.partialUtterance(
                            utterance=out["text"],
                            key=key + "-utt%d-part%d" % (utt, part))
                    part += 1
            asr.finalize_decoding()
            out = asr.get_output()
            mbr = MinimumBayesRisk(out["lattice"])
            confd = mbr.get_one_best_confidences()
            if pad_confidences:
                confd = _fit_confidences(confd, len(out["text"].split()))
            print(out)
            if asr_client is not None:
                asr_client.completeUtterance(utterance=out["text"],
                                             key=key + "-utt%d-part%d" %
                                             (utt, part),
                                             confidences=confd)

            feat_pipeline.get_adaptation_state(adaptation_state)
Beispiel #8
0
import sys
import argparse

import numpy as np
from kaldi.util.table import SequentialWaveReader
from kaldi.matrix import Matrix, _matrix_ext

if __name__ == '__main__':

    # Writes one "uttid num_bytes" line per utterance, where num_bytes is the
    # size of the single-channel audio as 16-bit (2-byte) samples.
    parser = argparse.ArgumentParser(description='wav.scp to byte files, i.e., '
                                                 'each line: uttid num_bytes')
    parser.add_argument('wav_rspecifier', type=str,
                        help='input wav.scp rspecifier')
    parser.add_argument('byte_file', type=str,
                        help='output file with "uttid num_bytes" lines')
    # Unknown arguments are tolerated and ignored.
    args, _ = parser.parse_known_args()

    with SequentialWaveReader(args.wav_rspecifier) as wav_reader, \
            open(args.byte_file, 'w') as bf:
        for uttid, wave in wav_reader:
            wave_data = wave.data()
            # has to be one channel
            assert wave_data.num_rows == 1
            bf.write('{} {}\n'.format(uttid, 2 * wave_data.num_cols))
Beispiel #9
0
def compute_vad(wav_rspecifier, feats_wspecifier, opts):
    """This function computes the vad based on ltsv features.

    The output is written in the file denoted by feats_wspecifier,
    and if the test_plot flag is set, it produces a plot.

    Args:
        wav_rspecifier: Kaldi rspecifier for reading wav files.
        feats_wspecifier: Kaldi wspecifier for writing the VAD vectors.
        opts: Options. See main function for list of options

    Returns:
        True if computation was successful for at least one file.
        False otherwise.
    """

    num_utts, num_success = 0, 0
    with SequentialWaveReader(wav_rspecifier) as reader, \
         VectorWriter(feats_wspecifier) as writer:

        for num_utts, (key, wave) in enumerate(reader, 1):
            if wave.duration < opts.min_duration:
                print(
                    "File: {} is too short ({} sec): "
                    "producing no output.".format(key, wave.duration),
                    file=sys.stderr,
                )
                continue

            num_chan = wave.data().num_rows
            if opts.channel >= num_chan:
                # The placeholders were previously printed unformatted.
                print(
                    "File with id {} has {} channels but you specified "
                    "channel {}, producing no output.".format(
                        key, num_chan, opts.channel),
                    file=sys.stderr,
                )
                continue

            # A channel of -1 selects the first channel.
            channel = 0 if opts.channel == -1 else opts.channel

            # frame_window / frame_shift are in milliseconds; convert to samples.
            fr_length_samples = int(opts.frame_window * wave.samp_freq *
                                    (10**(-3)))
            fr_shift_samples = int(opts.frame_shift * wave.samp_freq *
                                   (10**(-3)))

            assert opts.nfft >= fr_length_samples

            wav_data = np.squeeze(wave.data()[channel].numpy())

            sample_freqs, segment_times, spec = signal.spectrogram(
                wav_data,
                fs=wave.samp_freq,
                nperseg=fr_length_samples,
                nfft=opts.nfft,
                noverlap=fr_length_samples - fr_shift_samples,
                scaling="spectrum",
                mode="psd",
            )

            # Transpose so each row corresponds to one segment time.
            specT = np.transpose(spec)

            spect_n = ARMA.ApplyARMA(specT, opts.arma_order)

            ltsv_f = LTSV.ApplyLTSV(
                spect_n,
                opts.ltsv_ctx_window,
                opts.threshold,
                opts.slope,
                opts.sigmoid_scale,
            )

            vad_feat = DCTF.ApplyDCT(opts.dct_num_cep, opts.dct_ctx_window,
                                     ltsv_f)

            if opts.test_plot:
                show_plot(
                    key,
                    segment_times,
                    sample_freqs,
                    spec,
                    wave.duration,
                    wav_data,
                    vad_feat,
                )

            writer[key] = Vector(vad_feat)
            num_success += 1

            if num_utts % 10 == 0:
                print("Processed {} utterances".format(num_utts),
                      file=sys.stderr)

    print(
        "Done {} out of {} utterances".format(num_success, num_utts),
        file=sys.stderr,
    )

    return num_success != 0
Beispiel #10
0
    print(key, out["text"], flush=True)

print("-" * 80, flush=True)


# Define feature pipeline in code
def make_feat_pipeline(base, opts=None):
    """Build a per-utterance feature function around ``base``.

    Args:
        base: feature extractor providing ``compute_features(wave,
            samp_freq, vtln_warp)`` and ``dim()`` (e.g. ``Mfcc``).
        opts: options passed to ``compute_deltas``. When omitted, a fresh
            ``DeltaFeaturesOptions()`` is created per call. The former
            default instance was built once at definition time and shared
            by every caller.

    Returns:
        A function ``feat_pipeline(wav)``. It computes ``base`` features on
        the first channel of ``wav``, applies CMVN using statistics from that
        utterance only (in place), and returns ``compute_deltas(opts, feats)``.
    """
    if opts is None:
        opts = DeltaFeaturesOptions()

    def feat_pipeline(wav):
        feats = base.compute_features(wav.data()[0], wav.samp_freq, 1.0)
        # Per-utterance CMVN: accumulate and apply on the same features.
        cmvn = Cmvn(base.dim())
        cmvn.accumulate(feats)
        cmvn.apply(feats)
        return compute_deltas(opts, feats)

    return feat_pipeline


# Frame extraction at 16 kHz; higher-rate input may be downsampled.
frame_opts = FrameExtractionOptions()
frame_opts.samp_freq = 16000
frame_opts.allow_downsample = True
mfcc_opts = MfccOptions()
mfcc_opts.use_energy = False
mfcc_opts.frame_opts = frame_opts
feat_pipeline = make_feat_pipeline(Mfcc(mfcc_opts))

# Decode (the reader is closed once all utterances are processed)
with SequentialWaveReader("scp:wav.scp") as reader:
    for key, wav in reader:
        feats = feat_pipeline(wav)
        out = asr.decode(feats)
        print(key, out["text"], flush=True)
Beispiel #11
0
def compute_vad(wav_rspecifier, feats_wspecifier, opts):
    """This function computes the vad based on ltsv features.

    The output is written in the file denoted by feats_wspecifier,
    and if the test_plot flag is set, it produces a plot.

    Args:
        wav_rspecifier: An ark or scp specifier (Kaldi style) of the input audio.
        feats_wspecifier: An ark or scp specifier (Kaldi style) where the
            computed VAD vectors are written.
        opts: Options. See main function for list of options

    Returns:
        True if features were written for at least one utterance,
        False otherwise.
    """

    num_utts, num_success = 0, 0
    with SequentialWaveReader(wav_rspecifier) as reader, \
           VectorWriter(feats_wspecifier) as writer:

        for num_utts, (key, wave) in enumerate(reader, 1):
            if wave.duration < opts.min_duration:
                print("File: {} is too short ({} sec): producing no output.".
                      format(key, wave.duration),
                      file=sys.stderr)
                continue

            num_chan = wave.data().num_rows
            if opts.channel >= num_chan:
                # The placeholders were previously printed unformatted.
                print(
                    "File with id {} has {} channels but you specified "
                    "channel {}, producing no output.".format(
                        key, num_chan, opts.channel),
                    file=sys.stderr)
                continue
            # A channel of -1 selects the first channel.
            channel = 0 if opts.channel == -1 else opts.channel

            # frame_window / frame_shift are in milliseconds; convert to samples.
            fr_length_samples = int(opts.frame_window * wave.samp_freq *
                                    (10**(-3)))
            fr_shift_samples = int(opts.frame_shift * wave.samp_freq *
                                   (10**(-3)))

            try:

                wav_data = np.squeeze(wave.data()[channel].numpy())

                sample_freqs, segment_times, spec = signal.spectrogram(
                    wav_data,
                    fs=wave.samp_freq,
                    nperseg=fr_length_samples,
                    nfft=opts.nfft,
                    noverlap=fr_length_samples - fr_shift_samples,
                    scaling='spectrum',
                    mode='psd')

                specT = np.transpose(spec)

                spect_n = ARMA.ApplyARMA(specT, opts.arma_order)

                ltsv_f = LTSV.ApplyLTSV(spect_n, opts.ltsv_ctx_window,
                                        opts.threshold, opts.slope,
                                        opts.sigmoid_scale)

                vad_feat = DCTF.ApplyDCT(opts.dct_num_cep, opts.dct_ctx_window,
                                         ltsv_f)

                feats = Vector(vad_feat)

                if opts.test_plot:
                    show_plot(segment_times, sample_freqs, spec, wave,
                              wav_data, vad_feat)

            # Was a bare ``except:``, which also swallowed KeyboardInterrupt
            # and SystemExit; per-utterance failures are still skipped.
            except Exception:
                print("Failed to compute features for utterance",
                      key,
                      file=sys.stderr)
                continue

            writer[key] = feats
            num_success += 1

            if num_utts % 10 == 0:
                print("Processed {} utterances".format(num_utts),
                      file=sys.stderr)

    print("Done {} out of {} utterances".format(num_success, num_utts),
          file=sys.stderr)

    return num_success != 0
Beispiel #12
0
def compute_mfcc_feats(wav_rspecifier, feats_wspecifier, opts, mfcc_opts):
    """Compute MFCC features for every utterance and write them out.

    Args:
        wav_rspecifier: Kaldi rspecifier of the input wave files.
        feats_wspecifier: Kaldi wspecifier for the output feature matrices.
        opts: script options; reads ``vtln_map``, ``utt2spk``,
            ``min_duration``, ``channel``, ``vtln_warp`` and
            ``subtract_mean``.
        mfcc_opts: options used to construct the ``Mfcc`` extractor.

    Returns:
        True if features were written for at least one utterance,
        False otherwise.
    """
    mfcc = Mfcc(mfcc_opts)

    vtln_map_reader = None
    if opts.vtln_map:
        vtln_map_reader = RandomAccessFloatReaderMapped(
            opts.vtln_map, opts.utt2spk)
    elif opts.utt2spk:
        print("utt2spk option is needed only if vtln-map option is specified.",
              file=sys.stderr)

    num_utts, num_success = 0, 0
    try:
        with SequentialWaveReader(wav_rspecifier) as reader, \
             MatrixWriter(feats_wspecifier) as writer:
            for num_utts, (key, wave) in enumerate(reader, 1):
                if wave.duration < opts.min_duration:
                    print("File: {} is too short ({} sec): producing no output.".
                          format(key, wave.duration),
                          file=sys.stderr)
                    continue

                num_chan = wave.data().num_rows
                if opts.channel >= num_chan:
                    # The placeholders were previously printed unformatted.
                    print(
                        "File with id {} has {} channels but you specified "
                        "channel {}, producing no output.".format(
                            key, num_chan, opts.channel),
                        file=sys.stderr)
                    continue
                # A channel of -1 selects the first channel.
                channel = 0 if opts.channel == -1 else opts.channel

                if vtln_map_reader is not None:
                    if key not in vtln_map_reader:
                        print("No vtln-map entry for utterance-id (or speaker-id)",
                              key,
                              file=sys.stderr)
                        continue
                    vtln_warp = vtln_map_reader[key]
                else:
                    vtln_warp = opts.vtln_warp

                try:
                    feats = mfcc.compute_features(wave.data()[channel],
                                                  wave.samp_freq, vtln_warp)
                # Was a bare ``except:``; keep skipping failed utterances but
                # let KeyboardInterrupt / SystemExit propagate.
                except Exception:
                    print("Failed to compute features for utterance",
                          key,
                          file=sys.stderr)
                    continue

                if opts.subtract_mean:
                    # Subtract the per-utterance mean of each feature column.
                    mean = Vector(feats.num_cols)
                    mean.add_row_sum_mat_(1.0, feats)
                    mean.scale_(1.0 / feats.num_rows)
                    for i in range(feats.num_rows):
                        feats[i].add_vec_(-1.0, mean)

                writer[key] = feats
                num_success += 1

                if num_utts % 10 == 0:
                    print("Processed {} utterances".format(num_utts),
                          file=sys.stderr)

        print("Done {} out of {} utterances".format(num_success, num_utts),
              file=sys.stderr)
    finally:
        # Close the VTLN map even if reading or writing raised.
        if vtln_map_reader is not None:
            vtln_map_reader.close()

    return num_success != 0
        for text_path in text_pathes
    ])
    audio_transcripts.sort_values(by=0)
else:
    audio_transcripts = pd.concat([
        pd.read_csv(text_path, header=None, engine='python')
        for text_path in text_pathes
    ])
    audio_transcripts.sort_values(by=0)
    audio_transcripts = audio_transcripts[0].str.split(" ", 1, expand=True)
# Build a dict from column 0 (presumably the utterance id) to the
# lower-cased column 1 (presumably the transcript).
audio_transcripts[1] = audio_transcripts[1].str.lower()
audio_transcripts = audio_transcripts.set_index(0)[1].to_dict()

# Decode (whole utterance)
num_of_audiofiles = 0
for key, wav in SequentialWaveReader("scp:" + scp_path):
    feat_pipeline = OnlineNnetFeaturePipeline(feat_info)
    asr.set_input_pipeline(feat_pipeline)
    feat_pipeline.accept_waveform(wav.samp_freq, wav.data()[0])
    feat_pipeline.input_finished()

    # The scp key is used directly as the audio file path.
    audio_path = key
    try:
        audio, fs = sf.read(audio_path, dtype='int16')
    # Was a bare ``except:``, which also caught KeyboardInterrupt.
    except Exception:
        if VERBOSE:
            print("# WARNING :: Audio File " + audio_path + " not readable.\n")
        log_file.write("# WARNING :: Audio File " + audio_path +
                       " not readable.\n")
        continue
    # Duration in seconds (number of samples / sample rate).
    audio_len = len(audio) / fs
def compute_mfcc_feats(wav_rspecifier, feats_wspecifier, opts, mfcc_opts):
    """Cut each utterance into signal windows and write them as features.

    Each selected channel is scaled from the int16 range to floats and
    peak-normalised. It is then split with ``extract_windows`` using window
    sizes derived from ``opts`` (given in milliseconds).

    Args:
        wav_rspecifier: Kaldi rspecifier of the input wave files.
        feats_wspecifier: Kaldi wspecifier for the output feature matrices.
        opts: script options; reads ``sampling_rate``,
            ``label_window_length``, ``label_window_shift``,
            ``signal_window_length``, ``min_duration``, ``channel`` and
            ``subtract_mean``.
        mfcc_opts: options used to construct an ``Mfcc`` object.

    Returns:
        True if features were written for at least one utterance,
        False otherwise.
    """
    # NOTE(review): ``mfcc`` is never used below; kept in case constructing
    # it is relied on to validate ``mfcc_opts``.
    mfcc = Mfcc(mfcc_opts)

    # Shift by label window length so that feats align
    # (window sizes in opts are in milliseconds; convert to samples).
    lab_window_len_sample = int(
        (opts.sampling_rate * opts.label_window_length) / 1000)
    lab_window_shift_sample = int(
        (opts.sampling_rate * opts.label_window_shift) / 1000)
    sig_window_len_sample = int(
        (opts.sampling_rate * opts.signal_window_length) / 1000)

    num_utts, num_success = 0, 0
    with SequentialWaveReader(wav_rspecifier) as reader, \
         MatrixWriter(feats_wspecifier) as writer:
        for num_utts, (key, wave) in enumerate(reader, 1):
            if wave.duration < opts.min_duration:
                print("File: {} is too short ({} sec): producing no output.".
                      format(key, wave.duration),
                      file=sys.stderr)
                continue

            num_chan = wave.data().num_rows
            if opts.channel >= num_chan:
                # The placeholders were previously printed unformatted.
                print(
                    "File with id {} has {} channels but you specified "
                    "channel {}, producing no output.".format(
                        key, num_chan, opts.channel),
                    file=sys.stderr)
                continue
            # A channel of -1 selects the first channel.
            channel = 0 if opts.channel == -1 else opts.channel

            try:
                # Move signal from integers to floats
                signal = wave.data()[channel].numpy()
                signal = signal.astype(float) / 2**15  # 32768  # int to float
                peak = np.max(np.abs(signal))
                # Skip normalisation for an all-zero signal; dividing by a
                # zero peak would turn every sample into NaN.
                if peak > 0:
                    signal /= peak  # normalise

                # Extract windows
                feats = extract_windows(signal, sig_window_len_sample,
                                        lab_window_len_sample,
                                        lab_window_shift_sample)
            # Was a bare ``except:``; keep skipping failed utterances but let
            # KeyboardInterrupt / SystemExit propagate.
            except Exception:
                print("Failed to compute features for utterance",
                      key,
                      file=sys.stderr)
                continue

            if opts.subtract_mean:
                # Subtract the per-utterance mean of each feature column.
                mean = Vector(feats.num_cols)
                mean.add_row_sum_mat_(1.0, feats)
                mean.scale_(1.0 / feats.num_rows)
                for i in range(feats.num_rows):
                    feats[i].add_vec_(-1.0, mean)

            writer[key] = feats
            num_success += 1

            if num_utts % 10 == 0:
                print("Processed {} utterances".format(num_utts),
                      file=sys.stderr)

    print("Done {} out of {} utterances".format(num_success, num_utts),
          file=sys.stderr)

    return num_success != 0
Beispiel #15
0
# Define the decodable wrapper: (features, acoustic_scale) -> decodable
def make_decodable_wrapper(trans_model, acoustic_model):
    """Return a factory that builds ``DecodableAmDiagGmmScaled`` objects.

    The factory accepts ``(features, acoustic_scale)``. It constructs a
    decodable from the captured ``acoustic_model`` and ``trans_model``
    together with those two arguments.
    """
    return lambda features, acoustic_scale: DecodableAmDiagGmmScaled(
        acoustic_model, trans_model, features, acoustic_scale)


decodable_wrapper = make_decodable_wrapper(trans_model, acoustic_model)

# Directory holding the WSJ decoding graph and word symbol table.
# NOTE(review): machine-specific absolute path; adjust for your setup.
MODEL_DIR = "/home/dogan/tools/pykaldi/egs/models/wsj"

# Define the decoder
decoding_graph = read_fst_kaldi(MODEL_DIR + "/HCLG.fst")
decoder_opts = FasterDecoderOptions()
decoder_opts.beam = 13
decoder_opts.max_active = 7000
decoder = FasterDecoder(decoding_graph, decoder_opts)

# Define the recognizer
symbols = SymbolTable.read_text(MODEL_DIR + "/words.txt")
asr = Recognizer(decoder, decodable_wrapper, symbols)

# Decode wave files (the reader is closed once all utterances are done)
with SequentialWaveReader(
        "scp:/home/dogan/tools/pykaldi/egs/decoder/test2.scp") as reader:
    for key, wav in reader:
        feats = feat_pipeline(wav)
        out = asr.decode(feats)
        print(key, out["text"], flush=True)