Code Example #1
File: main.py  Project: eastonYi/eastonCode
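Note: these snippets are excerpts and omit their imports. They assume roughly the following shared setup (the module paths for array2text and size_variables are guesses; the real ones live inside the respective projects):

import sys
import logging

import numpy as np
import tensorflow as tf          # TF1-style API (tf.train.Saver, ConfigProto, ...)
import editdistance as ed        # ed.eval(a, b) = Levenshtein distance
from tqdm import tqdm

# hypothetical project-internal locations for the two helpers used below
from utils.tools import size_variables
from utils.textTools import array2text
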
def infer():
    tensor_global_step = tf.train.get_or_create_global_step()

    model_infer = args.Model(
        tensor_global_step,
        encoder=args.model.encoder.type,
        decoder=args.model.decoder.type,
        is_train=False,
        args=args)

    dataset_dev = args.dataset_test if args.dataset_test else args.dataset_dev

    saver = tf.train.Saver(max_to_keep=40)
    size_variables()

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        checkpoint = tf.train.latest_checkpoint(args.dirs.checkpoint_init)
        saver.restore(sess, checkpoint)

        total_cer_dist = 0
        total_cer_len = 0
        total_wer_dist = 0
        total_wer_len = 0
        with open(args.dir_model.name+'_decode.txt', 'w') as fw:
            for sample in tqdm(dataset_dev):
                if not sample:
                    continue
                # batch of one: add a batch axis to the features and pass the
                # true frame count as the length
                dict_feed = {model_infer.list_pl[0]: np.expand_dims(sample['feature'], axis=0),
                             model_infer.list_pl[1]: np.array([len(sample['feature'])])}
                sample_id, shape_batch, _ = sess.run(model_infer.list_run, feed_dict=dict_feed)
                # decoded, sample_id, decoded_sparse = sess.run(model_infer.list_run, feed_dict=dict_feed)
                res_txt = array2text(sample_id[0], args.data.unit, args.idx2token, eos_idx=args.eos_idx, min_idx=0, max_idx=args.dim_output-1)
                # align_txt = array2text(alignment[0], args.data.unit, args.idx2token, min_idx=0, max_idx=args.dim_output-1)
                ref_txt = array2text(sample['label'], args.data.unit, args.idx2token, eos_idx=args.eos_idx, min_idx=0, max_idx=args.dim_output-1)

                list_res_char = list(res_txt)
                list_ref_char = list(ref_txt)
                list_res_word = res_txt.split()
                list_ref_word = ref_txt.split()
                cer_dist = ed.eval(list_res_char, list_ref_char)
                cer_len = len(list_ref_char)
                wer_dist = ed.eval(list_res_word, list_ref_word)
                wer_len = len(list_ref_word)
                total_cer_dist += cer_dist
                total_cer_len += cer_len
                total_wer_dist += wer_dist
                total_wer_len += wer_len
                if cer_len == 0:
                    # empty reference: substitute a large dummy length so the
                    # running log below cannot divide by zero (the totals are
                    # already accumulated and stay correct)
                    cer_len = 1000
                    wer_len = 1000
                # only dump hypotheses that contain at least one word error
                if wer_dist/wer_len > 0:
                    fw.write('id:\t{} \nres:\t{}\nref:\t{}\n\n'.format(sample['id'], res_txt, ref_txt))
                logging.info('current cer: {:.3f}, wer: {:.3f};\tall cer {:.3f}, wer: {:.3f}'.format(
                    cer_dist/cer_len, wer_dist/wer_len, total_cer_dist/total_cer_len, total_wer_dist/total_wer_len))
        logging.info('dev CER: {:.3f}, WER: {:.3f}'.format(total_cer_dist/total_cer_len, total_wer_dist/total_wer_len))
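
The error-rate bookkeeping above is plain Levenshtein distance from the editdistance package: computed over characters for the CER and over whitespace-split tokens for the WER. A standalone example:

import editdistance as ed

res_txt = 'the cat sat'
ref_txt = 'the cat sat down'

# CER: character-level edit distance, normalized by the reference length
cer = ed.eval(list(res_txt), list(ref_txt)) / len(ref_txt)
# WER: word-level edit distance, normalized by the reference word count
wer = ed.eval(res_txt.split(), ref_txt.split()) / len(ref_txt.split())
print('CER {:.3f}, WER {:.3f}'.format(cer, wer))  # CER 0.312, WER 0.250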
Code Example #2
def decode_test(step, sample, model, sess, unit, args):
    # sample = dataset_dev[0]
    dict_feed = {
        model.list_pl[0]: np.expand_dims(sample['feature'], axis=0),
        model.list_pl[1]: np.array([len(sample['feature'])])
    }
    # run both decodes: the phone-level CTC branch (idx2phone) and the
    # token-level decoder (idx2token)
    (decoded_ctc, decoded), shape_sample, _ = sess.run(model.list_run,
                                                       feed_dict=dict_feed)
    res_ctc_txt = array2text(decoded_ctc[0], unit, args.idx2phone,
                             args.phone2idx)
    res_txt = array2text(decoded[0], unit, args.idx2token, args.token2idx)
    ref_txt = array2text(sample['label'], unit, args.idx2token, args.token2idx)

    logging.warning('length: {}, \nres_ctc: \n{}\nres: \n{}\nref: \n{}'.format(
        shape_sample[1], res_ctc_txt, res_txt, ref_txt))
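
Every feed dict on this page builds a batch of one: np.expand_dims adds the batch axis and the second placeholder carries the true frame count. A minimal sketch of the placeholder layout this implies (shapes, dtypes and the feature size are assumptions; the real models build list_pl internally):

import tensorflow as tf

dim_feature = 80  # per-frame feature size; the value here is only illustrative

pl_feature = tf.placeholder(tf.float32, [None, None, dim_feature])  # [batch, time, feat]
pl_feat_len = tf.placeholder(tf.int32, [None])                      # true frame counts, [batch]
list_pl = [pl_feature, pl_feat_len]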
Code Example #3
File: dataset.py  Project: eastonYi/eastonCode
def showing_csv_data():
    from dataProcessing.dataHelper import ASR_csv_DataSet
    # load the training csv as-is: no shuffling, no feature transform
    dataset_train = ASR_csv_DataSet(
        list_files=[args.dirs.train.data],
        args=args,
        _shuffle=False,
        transform=False)
    ref_txt = array2text(dataset_train[0]['label'], args.data.unit, args.idx2token)
    print(ref_txt)
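
array2text itself is not shown on this page. Judging from the call sites, it maps an id array to text via idx2token, truncates at eos_idx, and drops ids outside [min_idx, max_idx] (examples #2 and #6 come from a project version whose fourth positional argument is token2idx instead). A hypothetical reimplementation consistent with the eos_idx signature:

def array2text(ids, unit, idx2token, eos_idx=None, min_idx=0, max_idx=None):
    """Hypothetical sketch: decode token ids to text, stopping at eos_idx
    and skipping ids outside [min_idx, max_idx]."""
    tokens = []
    for i in ids:
        if eos_idx is not None and i == eos_idx:
            break
        if i < min_idx or (max_idx is not None and i > max_idx):
            continue
        tokens.append(idx2token[i])
    # char units concatenate directly; word/subword units are space-joined
    return ''.join(tokens) if unit == 'char' else ' '.join(tokens)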
Code Example #4
def decode_test(step,
                sample,
                model,
                sess,
                unit,
                idx2token,
                eos_idx=None,
                min_idx=0,
                max_idx=None):
    # sample = dataset_dev[0]
    dict_feed = {
        model.list_pl[0]: np.expand_dims(sample['feature'], axis=0),
        model.list_pl[1]: np.array([len(sample['feature'])])
    }
    sampled_id, shape_sample, _ = sess.run(model.list_run, feed_dict=dict_feed)

    res_txt = array2text(sampled_id[0], unit, idx2token, eos_idx, min_idx,
                         max_idx)
    ref_txt = array2text(sample['label'], unit, idx2token, eos_idx, min_idx,
                         max_idx)

    logging.warning('length: {}, res: \n{}\nref: \n{}'.format(
        shape_sample[1], res_txt, ref_txt))
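
A hypothetical call site for this decode_test, in the spirit of the commented-out sample = dataset_dev[0] hint (step, dataset_dev, model_infer, sess and args are borrowed from the other examples and not defined here):

# spot check during training: decode one fixed dev sample every 500 steps
if step % 500 == 0:
    decode_test(step, dataset_dev[0], model_infer, sess,
                args.data.unit, args.idx2token,
                eos_idx=args.eos_idx, max_idx=args.dim_output - 1)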
Code Example #5
File: main_clm.py  Project: eastonYi/asr-tf1
def infer_lm():
    tensor_global_step = tf.train.get_or_create_global_step()
    dataset_dev = args.dataset_test if args.dataset_test else args.dataset_dev

    model_lm = args.Model_LM(tensor_global_step,
                             training=False,
                             args=args.args_lm)

    args.lm_obj = model_lm
    saver_lm = tf.train.Saver(model_lm.variables())

    args.top_scope = tf.get_variable_scope()  # top-level scope
    args.lm_scope = model_lm.decoder.scope

    model_infer = args.Model(tensor_global_step,
                             encoder=args.model.encoder.type,
                             decoder=args.model.decoder.type,
                             training=False,
                             args=args)

    saver = tf.train.Saver(model_infer.variables())

    size_variables()

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        # restore the ASR model and the LM from their separate checkpoints
        checkpoint = tf.train.latest_checkpoint(args.dirs.checkpoint_init)
        checkpoint_lm = tf.train.latest_checkpoint(args.dirs.lm_checkpoint)
        saver.restore(sess, checkpoint)
        saver_lm.restore(sess, checkpoint_lm)

        total_cer_dist = 0
        total_cer_len = 0
        total_wer_dist = 0
        total_wer_len = 0
        with open(args.dir_model.name + '_decode.txt', 'w') as fw:
            # with open('/mnt/lustre/xushuang/easton/projects/asr-tf/exp/aishell/lm_acc.txt', 'w') as fw:
            for sample in tqdm(dataset_dev):
                if not sample:
                    continue
                dict_feed = {
                    model_infer.list_pl[0]:
                    np.expand_dims(sample['feature'], axis=0),
                    model_infer.list_pl[1]:
                    np.array([len(sample['feature'])])
                }
                sample_id, shape_batch, beam_decoded = sess.run(
                    model_infer.list_run, feed_dict=dict_feed)
                # decoded, sample_id, decoded_sparse = sess.run(model_infer.list_run, feed_dict=dict_feed)
                res_txt = array2text(sample_id[0],
                                     args.data.unit,
                                     args.idx2token,
                                     min_idx=0,
                                     max_idx=args.dim_output - 1)
                ref_txt = array2text(sample['label'],
                                     args.data.unit,
                                     args.idx2token,
                                     min_idx=0,
                                     max_idx=args.dim_output - 1)

                list_res_char = list(res_txt)
                list_ref_char = list(ref_txt)
                list_res_word = res_txt.split()
                list_ref_word = ref_txt.split()
                cer_dist = ed.eval(list_res_char, list_ref_char)
                cer_len = len(list_ref_char)
                wer_dist = ed.eval(list_res_word, list_ref_word)
                wer_len = len(list_ref_word)
                total_cer_dist += cer_dist
                total_cer_len += cer_len
                total_wer_dist += wer_dist
                total_wer_len += wer_len
                if cer_len == 0:
                    # dummy length: keeps the log below from dividing by zero
                    cer_len = 1000
                    wer_len = 1000
                # dump the beam candidates for utterances with word errors
                if wer_dist / wer_len > 0:
                    print('ref  ', ref_txt)
                    # walk the top-10 beam candidates together with their
                    # decoder scores and LM-reranked scores
                    for i, decoded, score, rerank_score in zip(
                            range(10), beam_decoded[0][0], beam_decoded[1][0],
                            beam_decoded[2][0]):
                        candidate = array2text(decoded,
                                               args.data.unit,
                                               args.idx2token,
                                               min_idx=0,
                                               max_idx=args.dim_output - 1)
                        print('res', i, candidate, score, rerank_score)
                        fw.write('res: {}; ref: {}\n'.format(
                            candidate, ref_txt))
                    fw.write('id:\t{} \nres:\t{}\nref:\t{}\n\n'.format(
                        sample['id'], res_txt, ref_txt))
                logging.info(
                    'current cer: {:.3f}, wer: {:.3f};\tall cer {:.3f}, wer: {:.3f}'
                    .format(cer_dist / cer_len, wer_dist / wer_len,
                            total_cer_dist / total_cer_len,
                            total_wer_dist / total_wer_len))
        logging.info('dev CER: {:.3f}, WER: {:.3f}'.format(
            total_cer_dist / total_cer_len, total_wer_dist / total_wer_len))
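
The unpacking in the rescoring loop above suggests the following layout for beam_decoded; this is inferred from the indexing, not documented in the source. Picking the 1-best after LM reranking would then look like:

import numpy as np

# assumed layout: beam_decoded[0][utt][beam] -> candidate id sequence,
# beam_decoded[1][utt][beam] -> decoder score,
# beam_decoded[2][utt][beam] -> LM-reranked score
beam_decoded = (
    [[[3, 5, 2], [3, 7, 2]]],   # hypotheses for one utterance, two beams
    [[-1.2, -1.5]],             # decoder scores
    [[-1.4, -1.1]],             # reranked scores
)
hyps, scores, rerank_scores = beam_decoded
best_beam = int(np.argmax(rerank_scores[0]))  # 1: reranking prefers beam 1
best_hyp = hyps[0][best_beam]                 # [3, 7, 2]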
Code Example #6
def infer():
    tensor_global_step = tf.train.get_or_create_global_step()

    model_infer = args.Model(tensor_global_step,
                             encoder=args.model.encoder.type,
                             decoder=args.model.decoder.type,
                             training=False,
                             args=args)

    dataset_dev = args.dataset_test

    saver = tf.train.Saver(max_to_keep=1)
    size_variables()

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        saver.restore(sess, args.dirs.checkpoint)

        total_cer_dist = 0
        total_cer_len = 0
        total_wer_dist = 0
        total_wer_len = 0
        with open('outputs/decoded.txt', 'w') as fw:
            for i, sample in enumerate(dataset_dev):
                if not sample:
                    continue
                dict_feed = {
                    model_infer.list_pl[0]:
                    np.expand_dims(sample['feature'], axis=0),
                    model_infer.list_pl[1]:
                    np.array([len(sample['feature'])])
                }
                sample_id, shape_batch, _ = sess.run(model_infer.list_run,
                                                     feed_dict=dict_feed)
                # decoded, sample_id, decoded_sparse = sess.run(model_infer.list_run, feed_dict=dict_feed)
                res_txt = array2text(sample_id[0], args.data.unit,
                                     args.idx2token, args.token2idx)
                ref_txt = array2text(sample['label'], args.data.unit,
                                     args.idx2token, args.token2idx)

                list_res_char = list(res_txt)
                list_ref_char = list(ref_txt)
                list_res_word = res_txt.split()
                list_ref_word = ref_txt.split()

                cer_dist = ed.eval(list_res_char, list_ref_char)
                cer_len = len(list_ref_char)
                wer_dist = ed.eval(list_res_word, list_ref_word)
                wer_len = len(list_ref_word)
                res_len = len(list_res_word)
                total_cer_dist += cer_dist
                total_cer_len += cer_len
                total_wer_dist += wer_dist
                total_wer_len += wer_len
                if cer_len == 0:
                    # dummy length: keeps the progress line below from
                    # dividing by zero
                    cer_len = 1000
                    wer_len = 1000
                # only dump hypotheses that contain at least one word error
                if wer_dist / wer_len > 0:
                    fw.write('uttid:\t{} \nres:\t{}\nref:\t{}\n\n'.format(
                        sample['uttid'], res_txt, ref_txt))
                sys.stdout.write(
                    '\rcurrent cer: {:.3f}, wer: {:.3f} res/ref: {:.3f};\tall cer {:.3f}, wer: {:.3f} {}/{} {:.2f}%'
                    .format(cer_dist / cer_len, wer_dist / wer_len,
                            res_len / wer_len, total_cer_dist / total_cer_len,
                            total_wer_dist / total_wer_len, i,
                            len(dataset_dev), i / len(dataset_dev) * 100))
                sys.stdout.flush()
        logging.info('dev CER: {:.3f}, WER: {:.3f}'.format(
            total_cer_dist / total_cer_len, total_wer_dist / total_wer_len))
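
Unlike example #1, this version reports progress with a bare carriage-return line instead of tqdm. The pattern in isolation:

import sys

total = 200
for i in range(total):
    # '\r' moves the cursor back to the start of the line, so each write
    # overwrites the previous progress message in place
    sys.stdout.write('\r{}/{} {:.1f}%'.format(i + 1, total, (i + 1) / total * 100))
    sys.stdout.flush()
sys.stdout.write('\n')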