def decode(jsonfile): """ entry point for model decoding, do some preparation work """ p, model, _, checkpointer, dataset_builder = build_model_from_jsonfile( jsonfile, 0) checkpointer.restore_from_best() solver = DecoderSolver(model, config=p.decode_config) dataset_builder = dataset_builder.load_csv( p.test_csv).compute_cmvn_if_necessary(True) solver.decode(dataset_builder.as_dataset(batch_size=1))
def decode(jsonfile): """ entry point for model decoding, do some preparation work """ p, model, _, checkpointer = build_model_from_jsonfile(jsonfile) if 'lm_type' in p.decode_config and p.decode_config['lm_type'] == "rnn": _, lm_model, _, lm_checkpointer = build_model_from_jsonfile(p.decode_config['lm_path']) lm_checkpointer.restore_from_best() checkpointer.restore_from_best() solver = DecoderSolver(model, config=p.decode_config, lm_model=lm_model) assert p.testset_config is not None dataset_builder = SUPPORTED_DATASET_BUILDER[p.dataset_builder](p.testset_config) dataset_builder = dataset_builder.compute_cmvn_if_necessary(True) solver.decode(dataset_builder.as_dataset(batch_size=1))
import os  # needed for the log-file existence check below


def decode(jsonfile, n=1, log_file=None):
    """ entry point for model decoding, do some preparation work """
    p, model, _, checkpointer = build_model_from_jsonfile(jsonfile)
    lm_model = None
    if 'lm_type' in p.decode_config and p.decode_config['lm_type'] == "rnn":
        _, lm_model, _, lm_checkpointer = build_model_from_jsonfile(
            p.decode_config['lm_path'])
        lm_checkpointer.restore_from_best()
    if log_file is None or not os.path.exists(log_file):
        raise IOError('The log file does not exist!')
    # collect per-epoch CTC accuracy from the training log
    checkpoint_wer_dict = {}
    with open(log_file) as log:
        for line in log:
            if 'epoch:' in line:
                print(line)
                splits = line.strip().split('\t')
                epoch = int(splits[0].split(' ')[-1])
                ctc_acc = float(splits[-1].split(' ')[-1])
                checkpoint_wer_dict[epoch] = ctc_acc
    # sort checkpoints by CTC accuracy, best first, and keep the top n
    checkpoint_wer_dict = {
        k: v
        for k, v in sorted(checkpoint_wer_dict.items(),
                           key=lambda item: item[1],
                           reverse=True)
    }
    ckpt_index_list = list(checkpoint_wer_dict.keys())[0:n]
    print('best_wer_checkpoint: ')
    print(ckpt_index_list)
    ckpt_v_list = []
    # restore variables from the n best checkpoints
    for idx in ckpt_index_list:
        ckpt_path = p.ckpt + 'ckpt-' + str(idx + 1)
        checkpointer.restore(ckpt_path)  # current variables will be updated
        var_list = []
        for i in model.trainable_variables:
            v = tf.constant(i.value())
            var_list.append(v)
        ckpt_v_list.append(var_list)
    # compute the average and assign it to the current variables
    for i in range(len(model.trainable_variables)):
        v = [
            tf.expand_dims(ckpt_v_list[j][i], [0])
            for j in range(len(ckpt_v_list))
        ]
        v = tf.reduce_mean(tf.concat(v, axis=0), axis=0)
        model.trainable_variables[i].assign(v)
    solver = DecoderSolver(model, config=p.decode_config, lm_model=lm_model)
    assert p.testset_config is not None
    dataset_builder = SUPPORTED_DATASET_BUILDER[p.dataset_builder](
        p.testset_config)
    dataset_builder = dataset_builder.compute_cmvn_if_necessary(True)
    solver.decode(dataset_builder.as_dataset(batch_size=1), dataset_builder)
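# The log parsing above assumes tab-separated training-log lines whose first
# field ends with the epoch number and whose last field ends with the CTC
# accuracy, e.g. (the field names here are an assumption inferred from the
# split logic, not a documented format):
#
#   INFO:absl:epoch: 12<TAB>loss: 34.2<TAB>CTC_Accuracy: 0.93
#
# so splits[0].split(' ')[-1] yields "12" and splits[-1].split(' ')[-1]
# yields "0.93". Because the dict is keyed by epoch, a later line for the
# same epoch overwrites an earlier one, and sorting with reverse=True keeps
# the n checkpoints with the highest accuracy.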
def decode(jsonfile, rank_size=1, rank=0):
    """ entry point for model decoding, do some preparation work """
    p, model, _, checkpointer = build_model_from_jsonfile(jsonfile)
    avg_num = 1 if 'model_avg_num' not in p.decode_config else \
        p.decode_config['model_avg_num']
    checkpointer.compute_nbest_avg(avg_num)
    lm_model = None
    if 'lm_type' in p.decode_config and p.decode_config['lm_type'] == "rnn":
        _, lm_model, _, lm_checkpointer = build_model_from_jsonfile(
            p.decode_config['lm_path'])
        lm_checkpointer.restore_from_best()
    solver = DecoderSolver(model, config=p.decode_config, lm_model=lm_model)
    assert p.testset_config is not None
    dataset_builder = SUPPORTED_DATASET_BUILDER[p.dataset_builder](
        p.testset_config)
    dataset_builder.shard(rank_size, rank)
    logging.info("shard result: %d" % len(dataset_builder))
    solver.decode(dataset_builder.as_dataset(batch_size=1),
                  rank_size=rank_size)
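# A minimal sketch of how rank_size/rank could be supplied to shard the test
# set across several decoding workers. Horovod is an assumption made purely
# for illustration; this file does not itself set up any distributed runtime:
#
#   import horovod.tensorflow as hvd
#   hvd.init()
#   decode(jsonfile, rank_size=hvd.size(), rank=hvd.rank())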
if __name__ == "__main__":
    logging.set_verbosity(logging.INFO)
    if len(sys.argv) < 2:
        logging.warning('Usage: python {} config_json_file'.format(sys.argv[0]))
        sys.exit()
    tf.random.set_seed(1)
    jsonfile = sys.argv[1]
    with open(jsonfile) as file:
        config = json.load(file)
    p = parse_config(config)
    DecoderSolver.initialize_devices(p.solver_gpu)
    decode(jsonfile, n=5, log_file='nohup.out')
import sys
import json
import tensorflow as tf
from absl import logging
from athena import DecoderSolver
from athena.main import (parse_config, build_model_from_jsonfile)


def decode(jsonfile):
    """ entry point for model decoding, do some preparation work """
    p, model, _, checkpointer, dataset_builder = build_model_from_jsonfile(
        jsonfile, 0)
    checkpointer.restore_from_best()
    solver = DecoderSolver(model, config=p.decode_config)
    dataset_builder = dataset_builder.load_csv(
        p.test_csv).compute_cmvn_if_necessary(True)
    solver.decode(dataset_builder.as_dataset(batch_size=1))


if __name__ == "__main__":
    logging.set_verbosity(logging.INFO)
    tf.random.set_seed(1)
    JSON_FILE = sys.argv[1]
    CONFIG = None
    with open(JSON_FILE) as f:
        CONFIG = json.load(f)
    PARAMS = parse_config(CONFIG)
    DecoderSolver.initialize_devices(PARAMS.solver_gpu)
    decode(JSON_FILE)
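# An illustrative shape for the config JSON these entry points consume. Only
# keys actually read above are listed, and every value is a made-up
# placeholder rather than a recommended setting:
#
#   {
#       "dataset_builder": "...",        # key into SUPPORTED_DATASET_BUILDER
#       "testset_config": { ... },       # handed to the dataset builder
#       "test_csv": "...",               # used by the load_csv variant above
#       "solver_gpu": [0],
#       "ckpt": "examples/.../ckpts/",   # prefix for "ckpt-<i>" files
#       "decode_config": {
#           "model_avg_num": 10,         # optional, defaults to 1
#           "lm_type": "rnn",            # optional RNN LM for decoding
#           "lm_path": "..."             # json config of the LM model
#       }
#   }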