示例#1
0
def load_model(config_file,
               online_config,
               models_path='models/',
               beam_size=10,
               frames_per_chunk=50):
    """Build an online nnet3 lattice recognizer from a YAML model description.

    The ``decoder`` section of the YAML file names the model artifacts. If
    ``online_config`` does not exist yet, a Kaldi-style options file is
    generated from the YAML settings and written there; otherwise the
    existing file is loaded unchanged.

    Args:
        config_file: Path to a YAML file with a top-level ``decoder`` mapping.
            Keys always read: ``model``, ``fst``, ``word-syms``. Keys read
            only when ``online_config`` must be generated: ``mfcc-config``,
            ``ivector-extraction-config``, ``endpoint-silence-phones``.
        online_config: Path of the options file that configures the feature
            pipeline and endpointing. Created if missing.
        models_path: Directory joined (``os.path.join``) in front of every
            file name taken from the YAML file; a trailing separator is not
            required.
        beam_size: Beam passed to the lattice-faster decoder.
        frames_per_chunk: Frames per chunk for the looped nnet computation.

    Returns:
        ``(asr, feat_info, decodable_opts)``: the recognizer, the feature
        pipeline info from which per-utterance pipelines are built, and the
        decodable options the recognizer was constructed with.

    Raises:
        KeyError: A required key is missing from the YAML ``decoder``
            section. If this happens while generating ``online_config``,
            no file is written.
        TypeError: ``endpoint-silence-phones`` is not a string.
    """
    # Read YAML file
    with open(config_file, 'r') as stream:
        model_yaml = yaml.safe_load(stream)

    decoder_yaml_opts = model_yaml['decoder']

    print(decoder_yaml_opts)

    def model_file(key):
        # Resolve a file name from the YAML decoder section against models_path.
        return os.path.join(models_path, decoder_yaml_opts[key])

    feat_opts = OnlineNnetFeaturePipelineConfig()
    endpoint_opts = OnlineEndpointConfig()

    if not os.path.isfile(online_config):
        print(online_config +
              ' does not exist. Trying to create it from yaml file settings.')
        print(
            'See also online_config_options.info.txt for what possible settings are.'
        )
        silence_phones = decoder_yaml_opts['endpoint-silence-phones']
        if not isinstance(silence_phones, str):
            # YAML 1.1 loaders (PyYAML) read an unquoted colon list such as
            # 1:2:3 as a base-60 integer, which silently loses the phone list.
            raise TypeError(
                "decoder.endpoint-silence-phones must be a string such as "
                "'1:2:3' (quote it in the YAML file); got %r"
                % (silence_phones,))
        # Resolve every key BEFORE opening the file: a KeyError half-way
        # through writing would leave a truncated config behind, and the
        # next run would load it as-is instead of regenerating it.
        config_lines = [
            "--add_pitch=False",
            "--mfcc_config=" + model_file('mfcc-config'),
            "--feature_type=mfcc",
            "--ivector_extraction_config="
            + model_file('ivector-extraction-config'),
            "--endpoint.silence-phones=" + silence_phones,
        ]
        with open(online_config, 'w') as online_config_file:
            online_config_file.write("\n".join(config_lines) + "\n")
    else:
        print("Loading online conf from:", online_config)

    # Both option groups are registered on one parser so the single
    # online_config file populates them.
    po = ParseOptions("")
    feat_opts.register(po)
    endpoint_opts.register(po)
    po.read_config_file(online_config)
    feat_info = OnlineNnetFeaturePipelineInfo.from_config(feat_opts)

    # Construct recognizer
    decoder_opts = LatticeFasterDecoderOptions()
    decoder_opts.beam = beam_size
    decoder_opts.max_active = 7000
    decodable_opts = NnetSimpleLoopedComputationOptions()
    decodable_opts.acoustic_scale = 1.0
    decodable_opts.frame_subsampling_factor = 3
    decodable_opts.frames_per_chunk = frames_per_chunk
    asr = NnetLatticeFasterOnlineRecognizer.from_files(
        model_file("model"),
        model_file("fst"),
        model_file("word-syms"),
        decoder_opts=decoder_opts,
        decodable_opts=decodable_opts,
        endpoint_opts=endpoint_opts)

    return asr, feat_info, decodable_opts
示例#2
0
# Define online feature pipeline
# Kaldi artifacts and options file read by this example (paths are relative
# to the current working directory).
ONLINE_CONF = "online.conf"
MODEL_FILE = "final.mdl"
GRAPH_FILE = "HCLG.fst"
WORDS_FILE = "words.txt"

# The feature-pipeline and endpoint option groups share one parser, so a
# single options file populates both; the pipeline info is then derived
# from the populated feature options.
po = ParseOptions("")
feat_opts = OnlineNnetFeaturePipelineConfig()
feat_opts.register(po)
endpoint_opts = OnlineEndpointConfig()
endpoint_opts.register(po)
po.read_config_file(ONLINE_CONF)
feat_info = OnlineNnetFeaturePipelineInfo.from_config(feat_opts)

# Construct recognizer: decoder search limits, then looped-nnet settings.
decoder_opts = LatticeFasterDecoderOptions()
decoder_opts.beam, decoder_opts.max_active = 13, 7000

decodable_opts = NnetSimpleLoopedComputationOptions()
(decodable_opts.acoustic_scale,
 decodable_opts.frame_subsampling_factor,
 decodable_opts.frames_per_chunk) = 1.0, 3, 150

asr = NnetLatticeFasterOnlineRecognizer.from_files(
    MODEL_FILE,
    GRAPH_FILE,
    WORDS_FILE,
    decoder_opts=decoder_opts,
    decodable_opts=decodable_opts,
    endpoint_opts=endpoint_opts,
)

# Decode (whole utterance)
# Iterate over the (key, wave) pairs yielded by the Kaldi-style "scp:wav.scp"
# specifier. For each one, a fresh feature pipeline is attached to the
# recognizer, the whole waveform is pushed in a single call, and the input is
# then marked as finished.
# NOTE(review): the loop body likely continues past the visible lines (e.g. a
# decode/print step) — confirm against the full file.
for key, wav in SequentialWaveReader("scp:wav.scp"):
    # A new pipeline per utterance, built from the shared feat_info; presumably
    # so no feature state carries over between utterances — TODO confirm.
    feat_pipeline = OnlineNnetFeaturePipeline(feat_info)
    asr.set_input_pipeline(feat_pipeline)
    # Row 0 of the wave data — presumably the first channel; TODO confirm.
    feat_pipeline.accept_waveform(wav.samp_freq, wav.data()[0])
    # No further audio is fed for this utterance in the visible code.
    feat_pipeline.input_finished()