Example no. 1
def infer():
    tensor_global_step = tf.train.get_or_create_global_step()

    model_infer = args.Model(
        tensor_global_step,
        encoder=args.model.encoder.type,
        decoder=args.model.decoder.type,
        is_train=False,
        args=args)

    dataset_dev = args.dataset_test if args.dataset_test else args.dataset_dev

    saver = tf.train.Saver(max_to_keep=40)
    size_variables()

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        checkpoint = tf.train.latest_checkpoint(args.dirs.checkpoint_init)
        saver.restore(sess, checkpoint)

        total_cer_dist = 0
        total_cer_len = 0
        total_wer_dist = 0
        total_wer_len = 0
        with open(args.dir_model.name+'_decode.txt', 'w') as fw:
            for sample in tqdm(dataset_dev):
                if not sample:
                    continue
                dict_feed = {model_infer.list_pl[0]: np.expand_dims(sample['feature'], axis=0),
                             model_infer.list_pl[1]: np.array([len(sample['feature'])])}
                sample_id, shape_batch, _ = sess.run(model_infer.list_run, feed_dict=dict_feed)
                # decoded, sample_id, decoded_sparse = sess.run(model_infer.list_run, feed_dict=dict_feed)
                res_txt = array2text(sample_id[0], args.data.unit, args.idx2token, eos_idx=args.eos_idx, min_idx=0, max_idx=args.dim_output-1)
                # align_txt = array2text(alignment[0], args.data.unit, args.idx2token, min_idx=0, max_idx=args.dim_output-1)
                ref_txt = array2text(sample['label'], args.data.unit, args.idx2token, eos_idx=args.eos_idx, min_idx=0, max_idx=args.dim_output-1)

                list_res_char = list(res_txt)
                list_ref_char = list(ref_txt)
                list_res_word = res_txt.split()
                list_ref_word = ref_txt.split()
                cer_dist = ed.eval(list_res_char, list_ref_char)
                cer_len = len(list_ref_char)
                wer_dist = ed.eval(list_res_word, list_ref_word)
                wer_len = len(list_ref_word)
                total_cer_dist += cer_dist
                total_cer_len += cer_len
                total_wer_dist += wer_dist
                total_wer_len += wer_len
                if cer_len == 0:
                    # avoid division by zero in the per-sample log below
                    cer_len = 1000
                    wer_len = 1000
                # only dump samples that contain at least one word error
                if wer_dist/wer_len > 0:
                    fw.write('id:\t{} \nres:\t{}\nref:\t{}\n\n'.format(sample['id'], res_txt, ref_txt))
                logging.info('current cer: {:.3f}, wer: {:.3f};\tall cer {:.3f}, wer: {:.3f}'.format(
                    cer_dist/cer_len, wer_dist/wer_len, total_cer_dist/total_cer_len, total_wer_dist/total_wer_len))
        logging.info('dev CER: {:.3f}, WER: {:.3f}'.format(total_cer_dist/total_cer_len, total_wer_dist/total_wer_len))
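
The metric bookkeeping inside the decode loop above can be isolated into a small helper. A minimal sketch, assuming ed is the editdistance package (so ed.eval(a, b) returns the Levenshtein distance between two sequences); the helper name cer_wer is illustrative and not part of the original code:

import editdistance as ed

def cer_wer(res_txt, ref_txt):
    """Return (cer_dist, cer_len, wer_dist, wer_len) for one hypothesis/reference pair."""
    # character-level edit distance against the reference length
    cer_dist = ed.eval(list(res_txt), list(ref_txt))
    cer_len = len(ref_txt)
    # word-level edit distance against the reference length
    res_words, ref_words = res_txt.split(), ref_txt.split()
    wer_dist = ed.eval(res_words, ref_words)
    wer_len = len(ref_words)
    return cer_dist, cer_len, wer_dist, wer_len

# usage: CER = sum of cer_dist / sum of cer_len over the dev set, likewise for WER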
Example no. 2
def infer():
    tensor_global_step = tf.train.get_or_create_global_step()

    model_infer = args.Model(tensor_global_step, is_train=False, args=args)
    input_pl = tf.placeholder(tf.int32, [None, None])
    len_pl = tf.placeholder(tf.int32, [None])
    score_T, distribution_T = model_infer.score(input_pl, len_pl)
    # sampled_op, num_samples_op = model_infer.sample(max_length=50)

    size_variables()
    saver = tf.train.Saver(max_to_keep=40)

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.Session(config=config) as sess:
        checkpoint = tf.train.latest_checkpoint(args.dirs.checkpoint_init)
        saver.restore(sess, checkpoint)

        dev(args.dataset_test, model_infer, sess)
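
Every example repeats the same session configuration. A minimal sketch of a shared helper for it (the name build_config is illustrative, not from the original code):

import tensorflow as tf

def build_config():
    # soft placement falls back to CPU when an op has no GPU kernel;
    # allow_growth stops TF from reserving all GPU memory up front
    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    return config

# usage: with tf.Session(config=build_config()) as sess: ...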
Example no. 3
def test():
    """
    containing sample test and score test
    """
    tensor_global_step = tf.train.get_or_create_global_step()

    model_infer = args.Model(tensor_global_step, is_train=False, args=args)
    sampled_op, num_samples_op = model_infer.sample(max_length=50)

    # pl_input = tf.placeholder([None, None])

    size_variables()

    saver = tf.train.Saver()

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.Session(config=config) as sess:
        checkpoint = tf.train.latest_checkpoint(args.dirs.checkpoint_init)
        saver.restore(sess, checkpoint)

        # samples = sess.run(sampled_op, feed_dict={num_samples_op: 10})
        # samples = array_idx2char(samples, args.idx2token, seperator='')
        # for i in samples:
        #     print(i)
        # res 0 就 想 我 的 时 候 自 然 会 需 要 面 包 -61.252304 -54.955883
        # res 1 就 像 我 饿 的 时 候 自 然 会 需 要 面 包 -65.39448 -60.87168
        # res 2 就 想 我 饿 的 时 候 自 然 会 需 要 面 包 -66.52325 -65.72158

        list_sents = ['郑 伟 电 视 剧 有 什 么', '郑 伟 电 视 剧 有 什 么 么']

        decoder_input, len_seq = array_char2idx(list_sents, args.token2idx, ' ')
        pad = np.ones([decoder_input.shape[0], 1],
                      dtype=decoder_input.dtype) * args.sos_idx
        decoder_input_sos = np.concatenate([pad, decoder_input], -1)
        score, distribution = model_infer.score(decoder_input_sos, len_seq)
        print(sess.run(score))
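
array_char2idx and array_idx2char are repo utilities whose exact signatures are only partly visible here. A minimal sketch of the conversion the score test relies on, assuming token2idx is a plain dict and tokens are space-separated; the helper below is a hypothetical stand-in, not the repo function:

import numpy as np

def char2idx_batch(list_sents, token2idx, sep=' '):
    """Convert separator-delimited sentences into a padded int matrix plus lengths."""
    seqs = [[token2idx[t] for t in sent.split(sep)] for sent in list_sents]
    len_seq = np.array([len(s) for s in seqs], dtype=np.int32)
    batch = np.zeros([len(seqs), len_seq.max()], dtype=np.int32)
    for i, s in enumerate(seqs):
        batch[i, :len(s)] = s
    return batch, len_seq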
Example no. 4
def train():
    print('reading data from ', args.dirs.train.tfdata)
    dataReader_train = TFReader(args.dirs.train.tfdata, args=args)
    batch_train = dataReader_train.fentch_batch_bucket()

    feat, label = readTFRecord(args.dirs.dev.tfdata, args, _shuffle=False, transform=True)
    dataloader_dev = ASRDataLoader(args.dataset_dev, args, feat, label, batch_size=args.batch_size, num_loops=1)
    # feat, label = readTFRecord(args.dirs.train.tfdata, args, shuffle=False, transform=True)
    # dataloader_train = ASRDataLoader(args.dataset_train, args, feat, label, batch_size=args.batch_size, num_loops=1)

    tensor_global_step = tf.train.get_or_create_global_step()

    model = args.Model(
        tensor_global_step,
        encoder=args.model.encoder.type,
        decoder=args.model.decoder.type,
        batch=batch_train,
        is_train=True,
        args=args)

    model_infer = args.Model(
        tensor_global_step,
        encoder=args.model.encoder.type,
        decoder=args.model.decoder.type,
        is_train=False,
        args=args)

    size_variables()
    start_time = datetime.now()
    checker = check_to_stop()

    saver = tf.train.Saver(max_to_keep=15)
    if args.dirs.lm_checkpoint:
        from tfTools.checkpointTools import list_variables

        list_lm_vars_pretrained = list_variables(args.dirs.lm_checkpoint)
        list_lm_vars = model.decoder.lm.variables

        dict_lm_vars = {}
        for var in list_lm_vars:
            if 'embedding' in var.name:
                for var_pre in list_lm_vars_pretrained:
                    if 'embedding' in var_pre[0]:
                        break
            else:
                name = var.name.split(model.decoder.lm.name)[1].split(':')[0]
                for var_pre in list_lm_vars_pretrained:
                    if name in var_pre[0]:
                        break
            # map 'var_name_in_checkpoint' -> var_in_graph; var_pre holds the
            # last checkpoint entry matched before the break above
            dict_lm_vars[var_pre[0]] = var

        saver_lm = tf.train.Saver(dict_lm_vars)

    summary = Summary(str(args.dir_log))

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        if args.dirs.checkpoint_init:
            checkpoint = tf.train.latest_checkpoint(args.dirs.checkpoint_init)
            saver.restore(sess, checkpoint)

        elif args.dirs.lm_checkpoint:
            lm_checkpoint = tf.train.latest_checkpoint(args.dirs.lm_checkpoint)
            saver_lm.restore(sess, lm_checkpoint)

        dataloader_dev.sess = sess

        batch_time = time()
        num_processed = 0
        progress = 0
        while progress < args.num_epochs:
            global_step, lr = sess.run([tensor_global_step, model.learning_rate])
            loss, shape_batch, _, _ = sess.run(model.list_run)

            num_processed += shape_batch[0]
            used_time = time()-batch_time
            batch_time = time()
            progress = num_processed/args.data.train.size_dataset

            if global_step % 10 == 0:
                logging.info('loss: {:.3f}\tbatch: {} lr:{:.6f} time:{:.2f}s {:.3f}% step: {}'.format(
                              loss, shape_batch, lr, used_time, progress*100.0, global_step))
                summary.summary_scalar('loss', loss, global_step)
                summary.summary_scalar('lr', lr, global_step)

            if global_step % args.save_step == args.save_step - 1:
                saver.save(get_session(sess), str(args.dir_checkpoint/'model'), global_step=global_step, write_meta_graph=True)

            if global_step % args.dev_step == args.dev_step - 1:
                cer, wer = dev(
                    step=global_step,
                    dataloader=dataloader_dev,
                    model=model_infer,
                    sess=sess,
                    unit=args.data.unit,
                    idx2token=args.idx2token,
                    eos_idx=args.eos_idx,
                    min_idx=0,
                    max_idx=args.dim_output-1)
                summary.summary_scalar('dev_cer', cer, global_step)
                summary.summary_scalar('dev_wer', wer, global_step)

            if global_step % args.decode_step == args.decode_step - 1:
                # decode_test(
                #     step=global_step,
                #     sample=args.dataset_test[10],
                #     model=model_infer,
                #     sess=sess,
                #     unit=args.data.unit,
                #     idx2token=args.idx2token,
                #     eos_idx=args.eos_idx,
                #     min_idx=3,
                #     max_idx=None)
                decode_test(
                    step=global_step,
                    sample=args.dataset_test[10],
                    model=model_infer,
                    sess=sess,
                    unit=args.data.unit,
                    idx2token=args.idx2token,
                    eos_idx=None,
                    min_idx=0,
                    max_idx=None)

            if args.num_steps and global_step > args.num_steps:
                sys.exit()

    logging.info('training duration: {:.2f}h'.format((datetime.now()-start_time).total_seconds()/3600))
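
The LM warm-start block in train() builds a map from checkpoint variable names to graph variables by substring matching, then hands it to a second tf.train.Saver. A minimal sketch of the same partial-restore pattern, using tf.train.list_variables directly; the function name and scope handling are illustrative assumptions:

import tensorflow as tf

def build_restore_map(checkpoint_dir, graph_vars, scope_prefix):
    """Map checkpoint variable names onto graph variables for a partial restore."""
    ckpt_names = [name for name, _ in tf.train.list_variables(checkpoint_dir)]
    restore_map = {}
    for var in graph_vars:
        # strip the enclosing scope prefix and the ':0' suffix to get the bare name
        bare = var.name.split(scope_prefix)[-1].split(':')[0]
        for ckpt_name in ckpt_names:
            if bare in ckpt_name:
                restore_map[ckpt_name] = var  # 'name_in_checkpoint': var_in_graph
                break
    return restore_map

# usage: saver_lm = tf.train.Saver(build_restore_map(ckpt_dir, model.decoder.lm.variables, 'lm'))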
Example no. 5
def infer_lm():
    tensor_global_step = tf.train.get_or_create_global_step()
    dataset_dev = args.dataset_test if args.dataset_test else args.dataset_dev

    model_lm = args.Model_LM(
        tensor_global_step,
        is_train=False,
        args=args.args_lm)

    args.lm_obj = model_lm
    saver_lm = tf.train.Saver(model_lm.variables())

    args.top_scope = tf.get_variable_scope()   # top-level scope
    args.lm_scope = model_lm.decoder.scope

    model_infer = args.Model(
        tensor_global_step,
        encoder=args.model.encoder.type,
        decoder=args.model.decoder.type,
        is_train=False,
        args=args)

    saver = tf.train.Saver(model_infer.variables())

    size_variables()

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        checkpoint = tf.train.latest_checkpoint(args.dirs.checkpoint_init)
        checkpoint_lm = tf.train.latest_checkpoint(args.dirs.lm_checkpoint)
        saver.restore(sess, checkpoint)
        saver_lm.restore(sess, checkpoint_lm)

        total_cer_dist = 0
        total_cer_len = 0
        total_wer_dist = 0
        total_wer_len = 0
        with open(args.dir_model.name+'_decode.txt', 'w') as fw:
        # with open('/mnt/lustre/xushuang/easton/projects/asr-tf/exp/aishell/lm_acc.txt', 'w') as fw:
            for sample in tqdm(dataset_dev):
                if not sample:
                    continue
                dict_feed = {model_infer.list_pl[0]: np.expand_dims(sample['feature'], axis=0),
                             model_infer.list_pl[1]: np.array([len(sample['feature'])])}
                sample_id, shape_batch, beam_decoded = sess.run(model_infer.list_run, feed_dict=dict_feed)
                # decoded, sample_id, decoded_sparse = sess.run(model_infer.list_run, feed_dict=dict_feed)
                res_txt = array2text(sample_id[0], args.data.unit, args.idx2token, min_idx=0, max_idx=args.dim_output-1)
                ref_txt = array2text(sample['label'], args.data.unit, args.idx2token, min_idx=0, max_idx=args.dim_output-1)

                list_res_char = list(res_txt)
                list_ref_char = list(ref_txt)
                list_res_word = res_txt.split()
                list_ref_word = ref_txt.split()
                cer_dist = ed.eval(list_res_char, list_ref_char)
                cer_len = len(list_ref_char)
                wer_dist = ed.eval(list_res_word, list_ref_word)
                wer_len = len(list_ref_word)
                total_cer_dist += cer_dist
                total_cer_len += cer_len
                total_wer_dist += wer_dist
                total_wer_len += wer_len
                if cer_len == 0:
                    # avoid division by zero in the per-sample log below
                    cer_len = 1000
                    wer_len = 1000
                # only dump samples that contain at least one word error
                if wer_dist/wer_len > 0:
                    print('ref ', ref_txt)
                    for i, decoded, score, rerank_score in zip(range(10), beam_decoded[0][0], beam_decoded[1][0], beam_decoded[2][0]):
                        candidate = array2text(decoded, args.data.unit, args.idx2token, min_idx=0, max_idx=args.dim_output-1)
                        print('res', i, candidate, score, rerank_score)
                        fw.write('res: {}; ref: {}\n'.format(candidate, ref_txt))
                    fw.write('id:\t{} \nres:\t{}\nref:\t{}\n\n'.format(sample['id'], res_txt, ref_txt))
                logging.info('current cer: {:.3f}, wer: {:.3f};\tall cer {:.3f}, wer: {:.3f}'.format(
                    cer_dist/cer_len, wer_dist/wer_len, total_cer_dist/total_cer_len, total_wer_dist/total_wer_len))
        logging.info('dev CER: {:.3f}, WER: {:.3f}'.format(total_cer_dist/total_cer_len, total_wer_dist/total_wer_len))
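
In the decode loop above, beam_decoded packs the beam candidates together with their original scores and LM-reranked scores. A minimal sketch of selecting the best hypothesis by reranked score, assuming beam_decoded = (ids, scores, rerank_scores), each with a leading batch dimension, as the indexing in the loop suggests:

import numpy as np

def best_by_rerank(beam_decoded):
    """Return (ids, score, rerank_score) of the top reranked candidate for batch item 0."""
    ids, scores, rerank_scores = beam_decoded
    best = int(np.argmax(rerank_scores[0]))
    return ids[0][best], float(scores[0][best]), float(rerank_scores[0][best])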
Example no. 6
def train():
    dataloader_train = LMDataLoader(args.dataset_train, 999999, args)
    # dataloader_train = PTBDataLoader(args.dataset_train, 80, args)
    tensor_global_step = tf.train.get_or_create_global_step()

    model = args.Model(tensor_global_step, is_train=True, args=args)
    model_infer = args.Model(tensor_global_step, is_train=False, args=args)
    input_pl = tf.placeholder(tf.int32, [None, None])
    len_pl = tf.placeholder(tf.int32, [None])
    score_T, distribution_T = model.score(input_pl, len_pl)

    size_variables()
    start_time = datetime.now()
    checker = check_to_stop()

    saver = tf.train.Saver(max_to_keep=15)
    summary = Summary(str(args.dir_log))

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        if args.dirs.checkpoint_init:
            checkpoint = tf.train.latest_checkpoint(args.dirs.checkpoint_init)
            saver.restore(sess, checkpoint)

        batch_time = time()
        num_processed = 0
        progress = 0
        for x, y, len_x, len_y in dataloader_train:
            global_step, lr = sess.run(
                [tensor_global_step, model.learning_rate])
            feed_dict = {
                model.list_pl[0]: x,
                model.list_pl[1]: y,
                model.list_pl[2]: len_x,
                model.list_pl[3]: len_y
            }
            loss, shape_batch, _ = sess.run(model.list_run,
                                            feed_dict=feed_dict)

            if global_step % 10 == 0:
                num_tokens = np.sum(len_x)
                ppl = np.exp(loss / num_tokens)
                summary.summary_scalar('loss', loss, global_step)
                summary.summary_scalar('ppl', ppl, global_step)
                summary.summary_scalar('lr', lr, global_step)

                num_processed += shape_batch[0]
                used_time = time() - batch_time
                batch_time = time()
                progress = num_processed / args.dataset_train.size_dataset
                logging.info(
                    'ppl: {:.3f}\tshape_batch: {} lr:{:.6f} time:{:.2f}s {:.3f}% step: {}'
                    .format(ppl, shape_batch, lr, used_time, progress * 100.0,
                            global_step))

            if global_step % args.save_step == args.save_step - 1:
                saver.save(get_session(sess),
                           str(args.dir_checkpoint / 'model'),
                           global_step=global_step,
                           write_meta_graph=False)

            if global_step % args.dev_step == args.dev_step - 1:
                ppl_dev = dev(args.dataset_dev, model_infer, sess)
                summary.summary_scalar('ppl_dev', ppl_dev, global_step)
                # accuracy = dev_external('/mnt/lustre/xushuang/easton/projects/asr-tf/exp/aishell/lm_acc.txt', model_infer, input_pl, len_pl, score_T, distribution_T, sess)
                # summary.summary_scalar('accuracy', accuracy, global_step)

    logging.info('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
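
The LM training loop reports perplexity as exp(loss / num_tokens), which assumes the loss returned by model.list_run is the batch's total (not mean) cross-entropy. A minimal numeric sketch of that bookkeeping:

import numpy as np

def perplexity(total_nll, num_tokens):
    """Perplexity from a summed negative log-likelihood over num_tokens tokens."""
    return float(np.exp(total_nll / max(num_tokens, 1)))

# e.g. a batch with total NLL 460.5 over 100 tokens gives ppl = exp(4.605) ≈ 100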