Exemplo n.º 1
0
def decode(jsonfile):
    """Restore the best checkpoint and decode the test CSV.

    Args:
        jsonfile: path to the JSON configuration handed to
            ``build_model_from_jsonfile`` (called with a second argument of 0).

    The test data is read from ``params.test_csv``. CMVN is computed when
    needed, and the data is decoded with a batch size of 1.
    """
    params, model, _, checkpointer, builder = build_model_from_jsonfile(
        jsonfile, 0)
    checkpointer.restore_from_best()
    solver = DecoderSolver(model, config=params.decode_config)
    builder = builder.load_csv(params.test_csv)
    builder = builder.compute_cmvn_if_necessary(True)
    test_dataset = builder.as_dataset(batch_size=1)
    solver.decode(test_dataset)
Exemplo n.º 2
0
def decode(jsonfile):
    """Decode the test set with the best checkpoint of the model in *jsonfile*.

    If ``decode_config['lm_type']`` equals ``"rnn"``, an auxiliary language
    model is built from ``decode_config['lm_path']``, restored from its own
    best checkpoint, and handed to the solver. Otherwise the solver receives
    ``lm_model=None``.

    Args:
        jsonfile: path to the model JSON configuration.
    """
    p, model, _, checkpointer = build_model_from_jsonfile(jsonfile)
    # Default to "no LM" so the DecoderSolver call below never references an
    # unbound name when lm_type is missing or is not "rnn".
    lm_model = None
    if 'lm_type' in p.decode_config and p.decode_config['lm_type'] == "rnn":
        _, lm_model, _, lm_checkpointer = build_model_from_jsonfile(
            p.decode_config['lm_path'])
        lm_checkpointer.restore_from_best()
    checkpointer.restore_from_best()
    solver = DecoderSolver(model, config=p.decode_config, lm_model=lm_model)
    assert p.testset_config is not None
    dataset_builder = SUPPORTED_DATASET_BUILDER[p.dataset_builder](
        p.testset_config)
    dataset_builder = dataset_builder.compute_cmvn_if_necessary(True)
    solver.decode(dataset_builder.as_dataset(batch_size=1))
Exemplo n.º 3
0
def decode(jsonfile, n=1, log_file=None):
    """Decode the test set using the element-wise average of the n best checkpoints.

    Checkpoints are ranked from a training log. Each line containing
    ``'epoch:'`` is printed and split on tabs:

    * the epoch number is the last space-separated token of the first field;
    * the score is the last space-separated token of the last field (called
      ``ctc_acc`` by the original parser).

    Epochs are ranked by score, highest first. If an epoch appears more than
    once, its last score wins. The top *n* epochs are restored one after
    another, and their trainable variables are averaged element-wise. The
    averages are assigned back to the live model before decoding.

    Args:
        jsonfile: path to the model JSON configuration.
        n: number of top-ranked checkpoints to average.
        log_file: path of the training log used for ranking.

    Raises:
        IOError: if *log_file* is None or does not exist.
        ValueError: if no checkpoint was selected (e.g. ``n <= 0`` or the log
            has no ``'epoch:'`` lines).
    """
    p, model, _, checkpointer = build_model_from_jsonfile(jsonfile)
    lm_model = None
    if 'lm_type' in p.decode_config and p.decode_config['lm_type'] == "rnn":
        _, lm_model, _, lm_checkpointer = build_model_from_jsonfile(
            p.decode_config['lm_path'])
        lm_checkpointer.restore_from_best()
    # os.path.exists(None) raises TypeError, so reject the default explicitly.
    if log_file is None or not os.path.exists(log_file):
        raise IOError('The file is not exist!')
    epoch_scores = {}
    with open(log_file) as log:
        for line in log:
            if 'epoch:' in line:
                print(line)
                splits = line.strip().split('\t')
                epoch = int(splits[0].split(' ')[-1])
                epoch_scores[epoch] = float(splits[-1].split(' ')[-1])
    # Highest score first; sorted() is stable, so ties keep log order.
    ranked = sorted(epoch_scores.items(), key=lambda item: item[1],
                    reverse=True)
    ckpt_index_list = [epoch for epoch, _ in ranked[:n]]
    print('best_wer_checkpoint: ')
    print(ckpt_index_list)
    if not ckpt_index_list:
        raise ValueError(
            'No checkpoint selected for averaging; check n and log_file.')
    # restore() overwrites the live variables, so snapshot each checkpoint's
    # values into constants before loading the next one.
    ckpt_v_list = []
    for idx in ckpt_index_list:
        # NOTE(review): checkpoint files appear to be numbered epoch + 1;
        # confirm against the training loop's save logic.
        checkpointer.restore(p.ckpt + 'ckpt-' + str(idx + 1))
        ckpt_v_list.append(
            [tf.constant(var.value()) for var in model.trainable_variables])
    # Element-wise mean over checkpoints, written back to the live variables.
    for i, var in enumerate(model.trainable_variables):
        stacked = tf.stack([snapshot[i] for snapshot in ckpt_v_list], axis=0)
        var.assign(tf.reduce_mean(stacked, axis=0))
    solver = DecoderSolver(model, config=p.decode_config, lm_model=lm_model)
    assert p.testset_config is not None
    dataset_builder = SUPPORTED_DATASET_BUILDER[p.dataset_builder](
        p.testset_config)
    dataset_builder = dataset_builder.compute_cmvn_if_necessary(True)
    solver.decode(dataset_builder.as_dataset(batch_size=1), dataset_builder)
Exemplo n.º 4
0
def decode(jsonfile, rank_size=1, rank=0):
    """Decode one shard of the test set after checkpoint averaging.

    ``decode_config['model_avg_num']`` (1 when the key is absent) is passed to
    ``checkpointer.compute_nbest_avg``; presumably that averages the best
    checkpoints into the model — confirm in the checkpointer implementation.
    An RNN language model is attached only when
    ``decode_config['lm_type'] == "rnn"``.

    Args:
        jsonfile: path to the model JSON configuration.
        rank_size: shard count, passed to ``dataset_builder.shard`` and to
            ``solver.decode``.
        rank: index of the shard decoded by this call.
    """
    p, model, _, checkpointer = build_model_from_jsonfile(jsonfile)
    if 'model_avg_num' in p.decode_config:
        avg_num = p.decode_config['model_avg_num']
    else:
        avg_num = 1
    checkpointer.compute_nbest_avg(avg_num)

    lm_model = None
    use_rnn_lm = ('lm_type' in p.decode_config
                  and p.decode_config['lm_type'] == "rnn")
    if use_rnn_lm:
        _, lm_model, _, lm_checkpointer = build_model_from_jsonfile(
            p.decode_config['lm_path'])
        lm_checkpointer.restore_from_best()

    solver = DecoderSolver(model, config=p.decode_config, lm_model=lm_model)
    assert p.testset_config is not None
    builder_cls = SUPPORTED_DATASET_BUILDER[p.dataset_builder]
    dataset_builder = builder_cls(p.testset_config)
    dataset_builder.shard(rank_size, rank)
    shard_msg = "shard result: %d" % len(dataset_builder)
    logging.info(shard_msg)
    test_set = dataset_builder.as_dataset(batch_size=1)
    solver.decode(test_set, rank_size=rank_size)
Exemplo n.º 5
0
        for i in model.trainable_variables:
            v = tf.constant(i.value())
            var_list.append(v)
        ckpt_v_list.append(var_list)
    #compute average, and assign to current variables
    for i in range(len(model.trainable_variables)):
        v = [tf.expand_dims(ckpt_v_list[j][i],[0]) for j in range(len(ckpt_v_list))]
        v = tf.reduce_mean(tf.concat(v,axis=0),axis=0)
        model.trainable_variables[i].assign(v)
    solver = DecoderSolver(model, config=p.decode_config, lm_model=lm_model)
    assert p.testset_config is not None
    dataset_builder = SUPPORTED_DATASET_BUILDER[p.dataset_builder](p.testset_config)
    dataset_builder = dataset_builder.compute_cmvn_if_necessary(True)
    solver.decode(dataset_builder.as_dataset(batch_size=1))


if __name__ == "__main__":
    logging.set_verbosity(logging.INFO)
    if len(sys.argv) < 2:
        logging.warning(
            'Usage: python {} config_json_file [n] [log_file]'.format(
                sys.argv[0]))
        # Non-zero exit status so calling scripts can detect the usage error.
        sys.exit(1)
    tf.random.set_seed(1)

    jsonfile = sys.argv[1]
    # Optional overrides; the defaults are the previously hard-coded values.
    n_best = int(sys.argv[2]) if len(sys.argv) > 2 else 5
    log_path = sys.argv[3] if len(sys.argv) > 3 else 'nohup.out'
    with open(jsonfile) as file:
        config = json.load(file)
    p = parse_config(config)

    DecoderSolver.initialize_devices(p.solver_gpu)
    decode(jsonfile, n=n_best, log_file=log_path)
Exemplo n.º 6
0
import sys
import json
import tensorflow as tf
from absl import logging
from athena import DecoderSolver
from athena.main import (parse_config, build_model_from_jsonfile)


def decode(jsonfile):
    """Load the best checkpoint described by *jsonfile* and decode the test CSV.

    Args:
        jsonfile: path to the model JSON configuration. It is forwarded to
            ``build_model_from_jsonfile`` together with a second argument of 0.

    The test data is loaded from ``hparams.test_csv``. CMVN is computed when
    required, and the data is decoded one utterance per batch.
    """
    hparams, net, _, ckpt_manager, data_builder = build_model_from_jsonfile(
        jsonfile, 0)
    ckpt_manager.restore_from_best()
    decoder = DecoderSolver(net, config=hparams.decode_config)
    loaded = data_builder.load_csv(hparams.test_csv)
    data_builder = loaded.compute_cmvn_if_necessary(True)
    decoder.decode(data_builder.as_dataset(batch_size=1))


if __name__ == "__main__":
    logging.set_verbosity(logging.INFO)
    if len(sys.argv) < 2:
        # Without this guard, sys.argv[1] below raises a bare IndexError.
        logging.warning('Usage: python {} config_json_file'.format(sys.argv[0]))
        sys.exit(1)
    tf.random.set_seed(1)

    JSON_FILE = sys.argv[1]
    with open(JSON_FILE) as f:
        CONFIG = json.load(f)
    PARAMS = parse_config(CONFIG)
    DecoderSolver.initialize_devices(PARAMS.solver_gpu)
    decode(JSON_FILE)