import sys
import logging

import numpy as np
import tensorflow as tf
import editdistance as ed
from tqdm import tqdm

# `args`, `array2text` and `size_variables` are project-local and assumed to be
# importable from elsewhere in this repo.


def infer():
    tensor_global_step = tf.train.get_or_create_global_step()
    model_infer = args.Model(
        tensor_global_step,
        encoder=args.model.encoder.type,
        decoder=args.model.decoder.type,
        is_train=False,
        args=args)
    # prefer the test set if configured, otherwise fall back to the dev set
    dataset_dev = args.dataset_test if args.dataset_test else args.dataset_dev

    saver = tf.train.Saver(max_to_keep=40)
    size_variables()

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False

    with tf.train.MonitoredTrainingSession(config=config) as sess:
        checkpoint = tf.train.latest_checkpoint(args.dirs.checkpoint_init)
        saver.restore(sess, checkpoint)

        total_cer_dist = 0
        total_cer_len = 0
        total_wer_dist = 0
        total_wer_len = 0
        with open(args.dir_model.name + '_decode.txt', 'w') as fw:
            for sample in tqdm(dataset_dev):
                if not sample:
                    continue
                dict_feed = {
                    model_infer.list_pl[0]: np.expand_dims(sample['feature'], axis=0),
                    model_infer.list_pl[1]: np.array([len(sample['feature'])])}
                sample_id, shape_batch, _ = sess.run(model_infer.list_run, feed_dict=dict_feed)
                # decoded, sample_id, decoded_sparse = sess.run(model_infer.list_run, feed_dict=dict_feed)

                res_txt = array2text(sample_id[0], args.data.unit, args.idx2token,
                                     eos_idx=args.eos_idx, min_idx=0, max_idx=args.dim_output - 1)
                # align_txt = array2text(alignment[0], args.data.unit, args.idx2token,
                #                        min_idx=0, max_idx=args.dim_output - 1)
                ref_txt = array2text(sample['label'], args.data.unit, args.idx2token,
                                     eos_idx=args.eos_idx, min_idx=0, max_idx=args.dim_output - 1)

                # edit distance over characters (CER) and whitespace-split words (WER)
                list_res_char = list(res_txt)
                list_ref_char = list(ref_txt)
                list_res_word = res_txt.split()
                list_ref_word = ref_txt.split()
                cer_dist = ed.eval(list_res_char, list_ref_char)
                cer_len = len(list_ref_char)
                wer_dist = ed.eval(list_res_word, list_ref_word)
                wer_len = len(list_ref_word)
                total_cer_dist += cer_dist
                total_cer_len += cer_len
                total_wer_dist += wer_dist
                total_wer_len += wer_len
                # guard each denominator against empty references before dividing
                if cer_len == 0:
                    cer_len = 1000
                if wer_len == 0:
                    wer_len = 1000
                # only write the erroneous utterances to the decode file
                if wer_dist / wer_len > 0:
                    fw.write('id:\t{} \nres:\t{}\nref:\t{}\n\n'.format(
                        sample['id'], res_txt, ref_txt))
                logging.info('current cer: {:.3f}, wer: {:.3f};\tall cer {:.3f}, wer: {:.3f}'.format(
                    cer_dist / cer_len, wer_dist / wer_len,
                    total_cer_dist / total_cer_len, total_wer_dist / total_wer_len))

        logging.info('dev CER {:.3f}: WER: {:.3f}'.format(
            total_cer_dist / total_cer_len, total_wer_dist / total_wer_len))
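# --- Illustrative sketch (not part of the original pipeline) ---
# The CER/WER bookkeeping in infer() boils down to two edit distances: one over
# characters and one over whitespace-split words. A minimal sketch, assuming
# only the `editdistance` package imported as `ed` above; the example strings
# are invented.
def _cer_wer_sketch(res_txt, ref_txt):
    cer = ed.eval(list(res_txt), list(ref_txt)) / max(len(ref_txt), 1)
    wer = ed.eval(res_txt.split(), ref_txt.split()) / max(len(ref_txt.split()), 1)
    return cer, wer

# e.g. _cer_wer_sketch('the cat sat', 'the cat spat') -> (1/12 CER, 1/3 WER)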
# decode_test variant for a joint model: list_run yields both a CTC
# (phone-level) hypothesis and a token-level hypothesis
def decode_test(step, sample, model, sess, unit, args):
    # sample = dataset_dev[0]
    dict_feed = {
        model.list_pl[0]: np.expand_dims(sample['feature'], axis=0),
        model.list_pl[1]: np.array([len(sample['feature'])])}
    (decoded_ctc, decoded), shape_sample, _ = sess.run(model.list_run, feed_dict=dict_feed)

    res_ctc_txt = array2text(decoded_ctc[0], unit, args.idx2phone, args.phone2idx)
    res_txt = array2text(decoded[0], unit, args.idx2token, args.token2idx)
    ref_txt = array2text(sample['label'], unit, args.idx2token, args.token2idx)

    logging.warning('length: {}, \nres_ctc: \n{}\nres: \n{}\nref: \n{}'.format(
        shape_sample[1], res_ctc_txt, res_txt, ref_txt))
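# array2text() is project-local. The sketch below is a plausible minimal
# reconstruction (an assumption for readers, not the repo's actual code):
# truncate at eos, drop out-of-range ids, then join tokens by unit.
def _array2text_sketch(ids, unit, idx2token, eos_idx=None, min_idx=0, max_idx=None):
    tokens = []
    for i in ids:
        if eos_idx is not None and i == eos_idx:
            break  # stop at the end-of-sequence symbol
        if i < min_idx or (max_idx is not None and i > max_idx):
            continue  # skip padding / out-of-range ids
        tokens.append(idx2token[i])
    # character units concatenate; word/subword units are space-joined
    return ''.join(tokens) if unit == 'char' else ' '.join(tokens)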
def showing_csv_data():
    from dataProcessing.dataHelper import ASR_csv_DataSet
    dataset_train = ASR_csv_DataSet(
        list_files=[args.dirs.train.data],
        args=args,
        _shuffle=False,
        transform=False)
    ref_txt = array2text(dataset_train[0]['label'], args.data.unit, args.idx2token)
    print(ref_txt)
# decode_test variant with explicit token-mapping arguments
def decode_test(step, sample, model, sess, unit, idx2token,
                eos_idx=None, min_idx=0, max_idx=None):
    # sample = dataset_dev[0]
    dict_feed = {
        model.list_pl[0]: np.expand_dims(sample['feature'], axis=0),
        model.list_pl[1]: np.array([len(sample['feature'])])}
    sampled_id, shape_sample, _ = sess.run(model.list_run, feed_dict=dict_feed)

    res_txt = array2text(sampled_id[0], unit, idx2token, eos_idx, min_idx, max_idx)
    ref_txt = array2text(sample['label'], unit, idx2token, eos_idx, min_idx, max_idx)

    logging.warning('length: {}, res: \n{}\nref: \n{}'.format(
        shape_sample[1], res_txt, ref_txt))
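# decode_test() is typically called periodically during training to eyeball a
# single dev sample. A hedged usage sketch, kept commented out; `sess`,
# `model_infer`, `step` and `dataset_dev` are assumed to exist in the
# surrounding training loop:
#
#     if step % 1000 == 0:
#         decode_test(step, dataset_dev[0], model_infer, sess,
#                     args.data.unit, args.idx2token,
#                     eos_idx=args.eos_idx, min_idx=0, max_idx=args.dim_output - 1)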
def infer_lm():
    tensor_global_step = tf.train.get_or_create_global_step()
    dataset_dev = args.dataset_test if args.dataset_test else args.dataset_dev

    # build the language model first so the ASR decoder can rescore with it
    model_lm = args.Model_LM(tensor_global_step, training=False, args=args.args_lm)
    args.lm_obj = model_lm
    saver_lm = tf.train.Saver(model_lm.variables())

    args.top_scope = tf.get_variable_scope()  # top-level scope
    args.lm_scope = model_lm.decoder.scope

    model_infer = args.Model(
        tensor_global_step,
        encoder=args.model.encoder.type,
        decoder=args.model.decoder.type,
        training=False,
        args=args)
    saver = tf.train.Saver(model_infer.variables())

    size_variables()

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False

    with tf.train.MonitoredTrainingSession(config=config) as sess:
        checkpoint = tf.train.latest_checkpoint(args.dirs.checkpoint_init)
        checkpoint_lm = tf.train.latest_checkpoint(args.dirs.lm_checkpoint)
        saver.restore(sess, checkpoint)
        saver_lm.restore(sess, checkpoint_lm)

        total_cer_dist = 0
        total_cer_len = 0
        total_wer_dist = 0
        total_wer_len = 0
        with open(args.dir_model.name + '_decode.txt', 'w') as fw:
        # with open('/mnt/lustre/xushuang/easton/projects/asr-tf/exp/aishell/lm_acc.txt', 'w') as fw:
            for sample in tqdm(dataset_dev):
                if not sample:
                    continue
                dict_feed = {
                    model_infer.list_pl[0]: np.expand_dims(sample['feature'], axis=0),
                    model_infer.list_pl[1]: np.array([len(sample['feature'])])}
                sample_id, shape_batch, beam_decoded = sess.run(
                    model_infer.list_run, feed_dict=dict_feed)
                # decoded, sample_id, decoded_sparse = sess.run(model_infer.list_run, feed_dict=dict_feed)

                res_txt = array2text(sample_id[0], args.data.unit, args.idx2token,
                                     min_idx=0, max_idx=args.dim_output - 1)
                ref_txt = array2text(sample['label'], args.data.unit, args.idx2token,
                                     min_idx=0, max_idx=args.dim_output - 1)

                list_res_char = list(res_txt)
                list_ref_char = list(ref_txt)
                list_res_word = res_txt.split()
                list_ref_word = ref_txt.split()
                cer_dist = ed.eval(list_res_char, list_ref_char)
                cer_len = len(list_ref_char)
                wer_dist = ed.eval(list_res_word, list_ref_word)
                wer_len = len(list_ref_word)
                total_cer_dist += cer_dist
                total_cer_len += cer_len
                total_wer_dist += wer_dist
                total_wer_len += wer_len
                # guard each denominator against empty references before dividing
                if cer_len == 0:
                    cer_len = 1000
                if wer_len == 0:
                    wer_len = 1000
                if wer_dist / wer_len > 0:
                    print('ref ', ref_txt)
                    # show the top-10 beam candidates with beam and LM-rerank scores
                    for i, decoded, score, rerank_score in zip(
                            range(10), beam_decoded[0][0], beam_decoded[1][0], beam_decoded[2][0]):
                        candidate = array2text(decoded, args.data.unit, args.idx2token,
                                               min_idx=0, max_idx=args.dim_output - 1)
                        print('res', i, candidate, score, rerank_score)
                        fw.write('res: {}; ref: {}\n'.format(candidate, ref_txt))
                    fw.write('id:\t{} \nres:\t{}\nref:\t{}\n\n'.format(
                        sample['id'], res_txt, ref_txt))
                logging.info('current cer: {:.3f}, wer: {:.3f};\tall cer {:.3f}, wer: {:.3f}'.format(
                    cer_dist / cer_len, wer_dist / wer_len,
                    total_cer_dist / total_cer_len, total_wer_dist / total_wer_len))

        logging.info('dev CER {:.3f}: WER: {:.3f}'.format(
            total_cer_dist / total_cer_len, total_wer_dist / total_wer_len))
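# --- Illustrative sketch (not part of the original pipeline) ---
# infer_lm() prints each beam candidate with its beam-search score and its LM
# rerank score. If the final pick were made host-side rather than in-graph,
# interpolating the two scores might look like this; `_pick_best_sketch` and
# the weight `lam` are hypothetical, not part of the repo:
def _pick_best_sketch(beam_decoded, lam=0.5):
    cands = beam_decoded[0][0]          # candidate id sequences, first utterance
    scores = beam_decoded[1][0]         # beam-search scores
    rerank_scores = beam_decoded[2][0]  # LM-reranked scores
    combined = [(1 - lam) * s + lam * r for s, r in zip(scores, rerank_scores)]
    return cands[int(np.argmax(combined))]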
# alternative infer() variant: reads args.dataset_test directly, restores a
# fixed checkpoint path, and reports progress in-place on stdout
def infer():
    tensor_global_step = tf.train.get_or_create_global_step()
    model_infer = args.Model(
        tensor_global_step,
        encoder=args.model.encoder.type,
        decoder=args.model.decoder.type,
        training=False,
        args=args)
    dataset_dev = args.dataset_test

    saver = tf.train.Saver(max_to_keep=1)
    size_variables()

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False

    with tf.train.MonitoredTrainingSession(config=config) as sess:
        saver.restore(sess, args.dirs.checkpoint)

        total_cer_dist = 0
        total_cer_len = 0
        total_wer_dist = 0
        total_wer_len = 0
        with open('outputs/decoded.txt', 'w') as fw:
            for i, sample in enumerate(dataset_dev):
                if not sample:
                    continue
                dict_feed = {
                    model_infer.list_pl[0]: np.expand_dims(sample['feature'], axis=0),
                    model_infer.list_pl[1]: np.array([len(sample['feature'])])}
                sample_id, shape_batch, _ = sess.run(model_infer.list_run, feed_dict=dict_feed)
                # decoded, sample_id, decoded_sparse = sess.run(model_infer.list_run, feed_dict=dict_feed)

                res_txt = array2text(sample_id[0], args.data.unit, args.idx2token, args.token2idx)
                ref_txt = array2text(sample['label'], args.data.unit, args.idx2token, args.token2idx)

                list_res_char = list(res_txt)
                list_ref_char = list(ref_txt)
                list_res_word = res_txt.split()
                list_ref_word = ref_txt.split()
                cer_dist = ed.eval(list_res_char, list_ref_char)
                cer_len = len(list_ref_char)
                wer_dist = ed.eval(list_res_word, list_ref_word)
                wer_len = len(list_ref_word)
                res_len = len(list_res_word)
                total_cer_dist += cer_dist
                total_cer_len += cer_len
                total_wer_dist += wer_dist
                total_wer_len += wer_len
                # guard each denominator against empty references before dividing
                if cer_len == 0:
                    cer_len = 1000
                if wer_len == 0:
                    wer_len = 1000
                if wer_dist / wer_len > 0:
                    fw.write('uttid:\t{} \nres:\t{}\nref:\t{}\n\n'.format(
                        sample['uttid'], res_txt, ref_txt))
                # in-place progress line on stdout
                sys.stdout.write(
                    '\rcurrent cer: {:.3f}, wer: {:.3f} res/ref: {:.3f};\tall cer {:.3f}, wer: {:.3f} {}/{} {:.2f}%'.format(
                        cer_dist / cer_len, wer_dist / wer_len, res_len / wer_len,
                        total_cer_dist / total_cer_len, total_wer_dist / total_wer_len,
                        i, len(dataset_dev), i / len(dataset_dev) * 100))
                sys.stdout.flush()

        logging.info('dev CER {:.3f}: WER: {:.3f}'.format(
            total_cer_dist / total_cer_len, total_wer_dist / total_wer_len))