def infer():
    tensor_global_step = tf.train.get_or_create_global_step()

    model_infer = args.Model(
        tensor_global_step,
        encoder=args.model.encoder.type,
        decoder=args.model.decoder.type,
        is_train=False,
        args=args)

    dataset_dev = args.dataset_test if args.dataset_test else args.dataset_dev

    saver = tf.train.Saver(max_to_keep=40)
    size_variables()

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        checkpoint = tf.train.latest_checkpoint(args.dirs.checkpoint_init)
        saver.restore(sess, checkpoint)

        total_cer_dist = 0
        total_cer_len = 0
        total_wer_dist = 0
        total_wer_len = 0
        with open(args.dir_model.name + '_decode.txt', 'w') as fw:
            for sample in tqdm(dataset_dev):
                if not sample:
                    continue
                dict_feed = {model_infer.list_pl[0]: np.expand_dims(sample['feature'], axis=0),
                             model_infer.list_pl[1]: np.array([len(sample['feature'])])}
                sample_id, shape_batch, _ = sess.run(model_infer.list_run, feed_dict=dict_feed)
                # decoded, sample_id, decoded_sparse = sess.run(model_infer.list_run, feed_dict=dict_feed)

                res_txt = array2text(sample_id[0], args.data.unit, args.idx2token,
                                     eos_idx=args.eos_idx, min_idx=0, max_idx=args.dim_output-1)
                # align_txt = array2text(alignment[0], args.data.unit, args.idx2token, min_idx=0, max_idx=args.dim_output-1)
                ref_txt = array2text(sample['label'], args.data.unit, args.idx2token,
                                     eos_idx=args.eos_idx, min_idx=0, max_idx=args.dim_output-1)

                list_res_char = list(res_txt)
                list_ref_char = list(ref_txt)
                list_res_word = res_txt.split()
                list_ref_word = ref_txt.split()
                cer_dist = ed.eval(list_res_char, list_ref_char)
                cer_len = len(list_ref_char)
                wer_dist = ed.eval(list_res_word, list_ref_word)
                wer_len = len(list_ref_word)
                total_cer_dist += cer_dist
                total_cer_len += cer_len
                total_wer_dist += wer_dist
                total_wer_len += wer_len
                if cer_len == 0:
                    # avoid division by zero when the reference is empty;
                    # totals were already updated above, so only the per-sample log is affected
                    cer_len = 1000
                    wer_len = 1000
                if wer_dist / wer_len > 0:
                    fw.write('id:\t{} \nres:\t{}\nref:\t{}\n\n'.format(sample['id'], res_txt, ref_txt))
                logging.info('current cer: {:.3f}, wer: {:.3f};\tall cer {:.3f}, wer: {:.3f}'.format(
                    cer_dist/cer_len, wer_dist/wer_len,
                    total_cer_dist/total_cer_len, total_wer_dist/total_wer_len))
        logging.info('dev CER {:.3f}: WER: {:.3f}'.format(
            total_cer_dist/total_cer_len, total_wer_dist/total_wer_len))
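
# Illustrative sketch (not part of the original pipeline): how the CER/WER numbers in
# infer() are obtained for one hypothesis/reference pair. It assumes `ed` above is the
# `editdistance` package (editdistance.eval(seq_a, seq_b)); the helper name `error_rates`
# is made up here for illustration only.
import editdistance

def error_rates(res_txt, ref_txt):
    """Return (cer, wer) for one pair: CER over character sequences,
    WER over whitespace-split tokens, mirroring the loop in infer()."""
    cer_dist = editdistance.eval(list(res_txt), list(ref_txt))
    wer_dist = editdistance.eval(res_txt.split(), ref_txt.split())
    cer = cer_dist / max(len(ref_txt), 1)
    wer = wer_dist / max(len(ref_txt.split()), 1)
    return cer, wer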
def infer():
    tensor_global_step = tf.train.get_or_create_global_step()

    model_infer = args.Model(tensor_global_step, is_train=False, args=args)
    input_pl = tf.placeholder(tf.int32, [None, None])
    len_pl = tf.placeholder(tf.int32, [None])
    score_T, distribution_T = model_infer.score(input_pl, len_pl)
    # sampled_op, num_samples_op = model_infer.sample(max_length=50)
    size_variables()

    saver = tf.train.Saver(max_to_keep=40)

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.Session(config=config) as sess:
        checkpoint = tf.train.latest_checkpoint(args.dirs.checkpoint_init)
        saver.restore(sess, checkpoint)

        dev(args.dataset_test, model_infer, sess)
def test():
    """
    contains the sampling test and the scoring test
    """
    tensor_global_step = tf.train.get_or_create_global_step()

    model_infer = args.Model(tensor_global_step, is_train=False, args=args)
    sampled_op, num_samples_op = model_infer.sample(max_length=50)
    # pl_input = tf.placeholder([None, None])
    size_variables()

    saver = tf.train.Saver()

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.Session(config=config) as sess:
        checkpoint = tf.train.latest_checkpoint(args.dirs.checkpoint_init)
        saver.restore(sess, checkpoint)

        # sampling test:
        # samples = sess.run(sampled_op, feed_dict={num_samples_op: 10})
        # samples = array_idx2char(samples, args.idx2token, seperator='')
        # for i in samples:
        #     print(i)

        # res 0 就 想 我 的 时 候 自 然 会 需 要 面 包 -61.252304 -54.955883
        # res 1 就 像 我 饿 的 时 候 自 然 会 需 要 面 包 -65.39448 -60.87168
        # res 2 就 想 我 饿 的 时 候 自 然 会 需 要 面 包 -66.52325 -65.72158

        # scoring test:
        list_sents = ['郑 伟 电 视 剧 有 什 么', '郑 伟 电 视 剧 有 什 么 么']
        decoder_input, len_seq = array_char2idx(list_sents, args.token2idx, ' ')
        # prepend <sos> to each sequence before scoring
        pad = np.ones([decoder_input.shape[0], 1], dtype=decoder_input.dtype) * args.sos_idx
        decoder_input_sos = np.concatenate([pad, decoder_input], -1)
        score, distribution = model_infer.score(decoder_input_sos, len_seq)

        print(sess.run(score))
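
# Illustrative sketch (not from the original repo): the <sos>-prepending done in test()
# above, written as a standalone helper so the shape handling is explicit. `sos_idx` is
# whatever args.sos_idx holds; the helper name `prepend_sos` is hypothetical.
import numpy as np

def prepend_sos(decoder_input, sos_idx):
    """Prepend an <sos> column to a [batch, time] id matrix, as test() does before scoring."""
    pad = np.full([decoder_input.shape[0], 1], sos_idx, dtype=decoder_input.dtype)
    return np.concatenate([pad, decoder_input], axis=-1)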
def train():
    print('reading data from ', args.dirs.train.tfdata)
    dataReader_train = TFReader(args.dirs.train.tfdata, args=args)
    batch_train = dataReader_train.fentch_batch_bucket()

    feat, label = readTFRecord(args.dirs.dev.tfdata, args, _shuffle=False, transform=True)
    dataloader_dev = ASRDataLoader(args.dataset_dev, args, feat, label,
                                   batch_size=args.batch_size, num_loops=1)
    # feat, label = readTFRecord(args.dirs.train.tfdata, args, shuffle=False, transform=True)
    # dataloader_train = ASRDataLoader(args.dataset_train, args, feat, label, batch_size=args.batch_size, num_loops=1)

    tensor_global_step = tf.train.get_or_create_global_step()

    model = args.Model(
        tensor_global_step,
        encoder=args.model.encoder.type,
        decoder=args.model.decoder.type,
        batch=batch_train,
        is_train=True,
        args=args)

    model_infer = args.Model(
        tensor_global_step,
        encoder=args.model.encoder.type,
        decoder=args.model.decoder.type,
        is_train=False,
        args=args)

    size_variables()
    start_time = datetime.now()
    checker = check_to_stop()

    saver = tf.train.Saver(max_to_keep=15)
    if args.dirs.lm_checkpoint:
        from tfTools.checkpointTools import list_variables

        # map pretrained LM variables in the checkpoint to the LM variables in this graph
        list_lm_vars_pretrained = list_variables(args.dirs.lm_checkpoint)
        list_lm_vars = model.decoder.lm.variables
        dict_lm_vars = {}
        for var in list_lm_vars:
            if 'embedding' in var.name:
                for var_pre in list_lm_vars_pretrained:
                    if 'embedding' in var_pre[0]:
                        break
            else:
                name = var.name.split(model.decoder.lm.name)[1].split(':')[0]
                for var_pre in list_lm_vars_pretrained:
                    if name in var_pre[0]:
                        break
            # 'var_name_in_checkpoint': var_in_graph
            dict_lm_vars[var_pre[0]] = var

        saver_lm = tf.train.Saver(dict_lm_vars)

    summary = Summary(str(args.dir_log))

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        if args.dirs.checkpoint_init:
            checkpoint = tf.train.latest_checkpoint(args.dirs.checkpoint_init)
            saver.restore(sess, checkpoint)
        elif args.dirs.lm_checkpoint:
            lm_checkpoint = tf.train.latest_checkpoint(args.dirs.lm_checkpoint)
            saver_lm.restore(sess, lm_checkpoint)

        dataloader_dev.sess = sess

        batch_time = time()
        num_processed = 0
        progress = 0
        while progress < args.num_epochs:
            global_step, lr = sess.run([tensor_global_step, model.learning_rate])
            loss, shape_batch, _, _ = sess.run(model.list_run)

            num_processed += shape_batch[0]
            used_time = time() - batch_time
            batch_time = time()
            progress = num_processed / args.data.train.size_dataset

            if global_step % 10 == 0:
                logging.info('loss: {:.3f}\tbatch: {} lr:{:.6f} time:{:.2f}s {:.3f}% step: {}'.format(
                    loss, shape_batch, lr, used_time, progress*100.0, global_step))
                summary.summary_scalar('loss', loss, global_step)
                summary.summary_scalar('lr', lr, global_step)

            if global_step % args.save_step == args.save_step - 1:
                saver.save(get_session(sess), str(args.dir_checkpoint/'model'),
                           global_step=global_step, write_meta_graph=True)

            if global_step % args.dev_step == args.dev_step - 1:
                cer, wer = dev(
                    step=global_step,
                    dataloader=dataloader_dev,
                    model=model_infer,
                    sess=sess,
                    unit=args.data.unit,
                    idx2token=args.idx2token,
                    eos_idx=args.eos_idx,
                    min_idx=0,
                    max_idx=args.dim_output-1)
                summary.summary_scalar('dev_cer', cer, global_step)
                summary.summary_scalar('dev_wer', wer, global_step)

            if global_step % args.decode_step == args.decode_step - 1:
                # decode_test(
                #     step=global_step,
                #     sample=args.dataset_test[10],
                #     model=model_infer,
                #     sess=sess,
                #     unit=args.data.unit,
                #     idx2token=args.idx2token,
                #     eos_idx=args.eos_idx,
                #     min_idx=3,
                #     max_idx=None)
                decode_test(
                    step=global_step,
                    sample=args.dataset_test[10],
                    model=model_infer,
                    sess=sess,
                    unit=args.data.unit,
                    idx2token=args.idx2token,
                    eos_idx=None,
                    min_idx=0,
                    max_idx=None)

            if args.num_steps and global_step > args.num_steps:
                sys.exit()

    logging.info('training duration: {:.2f}h'.format((datetime.now()-start_time).total_seconds()/3600))
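
# Illustrative sketch (not the repo's code): the idea behind the partial LM restore in
# train() above. A {name_in_checkpoint: variable_in_graph} dict is handed to tf.train.Saver
# so that only the pretrained LM weights are loaded from args.dirs.lm_checkpoint. The
# helper name `build_restore_map` and its arguments are made up for this sketch, and the
# embedding special case from the original loop is omitted.
def build_restore_map(graph_vars, ckpt_var_names, scope_name):
    """Match each graph variable to the first checkpoint entry whose name contains
    the variable's name relative to `scope_name`."""
    restore_map = {}
    for var in graph_vars:
        rel_name = var.name.split(scope_name)[-1].split(':')[0]
        for ckpt_name in ckpt_var_names:
            if rel_name in ckpt_name:
                restore_map[ckpt_name] = var
                break
    return restore_map

# usage sketch:
# saver_lm = tf.train.Saver(build_restore_map(model.decoder.lm.variables,
#                                             [name for name, _ in list_lm_vars_pretrained],
#                                             model.decoder.lm.name))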
def infer_lm():
    tensor_global_step = tf.train.get_or_create_global_step()
    dataset_dev = args.dataset_test if args.dataset_test else args.dataset_dev

    model_lm = args.Model_LM(
        tensor_global_step,
        is_train=False,
        args=args.args_lm)

    args.lm_obj = model_lm
    saver_lm = tf.train.Saver(model_lm.variables())

    args.top_scope = tf.get_variable_scope()  # top-level scope
    args.lm_scope = model_lm.decoder.scope

    model_infer = args.Model(
        tensor_global_step,
        encoder=args.model.encoder.type,
        decoder=args.model.decoder.type,
        is_train=False,
        args=args)
    saver = tf.train.Saver(model_infer.variables())

    size_variables()

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        checkpoint = tf.train.latest_checkpoint(args.dirs.checkpoint_init)
        checkpoint_lm = tf.train.latest_checkpoint(args.dirs.lm_checkpoint)
        saver.restore(sess, checkpoint)
        saver_lm.restore(sess, checkpoint_lm)

        total_cer_dist = 0
        total_cer_len = 0
        total_wer_dist = 0
        total_wer_len = 0
        with open(args.dir_model.name + '_decode.txt', 'w') as fw:
        # with open('/mnt/lustre/xushuang/easton/projects/asr-tf/exp/aishell/lm_acc.txt', 'w') as fw:
            for sample in tqdm(dataset_dev):
                if not sample:
                    continue
                dict_feed = {model_infer.list_pl[0]: np.expand_dims(sample['feature'], axis=0),
                             model_infer.list_pl[1]: np.array([len(sample['feature'])])}
                sample_id, shape_batch, beam_decoded = sess.run(model_infer.list_run, feed_dict=dict_feed)
                # decoded, sample_id, decoded_sparse = sess.run(model_infer.list_run, feed_dict=dict_feed)

                res_txt = array2text(sample_id[0], args.data.unit, args.idx2token,
                                     min_idx=0, max_idx=args.dim_output-1)
                ref_txt = array2text(sample['label'], args.data.unit, args.idx2token,
                                     min_idx=0, max_idx=args.dim_output-1)

                list_res_char = list(res_txt)
                list_ref_char = list(ref_txt)
                list_res_word = res_txt.split()
                list_ref_word = ref_txt.split()
                cer_dist = ed.eval(list_res_char, list_ref_char)
                cer_len = len(list_ref_char)
                wer_dist = ed.eval(list_res_word, list_ref_word)
                wer_len = len(list_ref_word)
                total_cer_dist += cer_dist
                total_cer_len += cer_len
                total_wer_dist += wer_dist
                total_wer_len += wer_len
                if cer_len == 0:
                    cer_len = 1000
                    wer_len = 1000
                if wer_dist / wer_len > 0:
                    print('ref ', ref_txt)
                    for i, decoded, score, rerank_score in zip(range(10), beam_decoded[0][0],
                                                               beam_decoded[1][0], beam_decoded[2][0]):
                        candidate = array2text(decoded, args.data.unit, args.idx2token,
                                               min_idx=0, max_idx=args.dim_output-1)
                        print('res', i, candidate, score, rerank_score)
                        fw.write('res: {}; ref: {}\n'.format(candidate, ref_txt))
                    fw.write('id:\t{} \nres:\t{}\nref:\t{}\n\n'.format(sample['id'], res_txt, ref_txt))
                logging.info('current cer: {:.3f}, wer: {:.3f};\tall cer {:.3f}, wer: {:.3f}'.format(
                    cer_dist/cer_len, wer_dist/wer_len,
                    total_cer_dist/total_cer_len, total_wer_dist/total_wer_len))
        logging.info('dev CER {:.3f}: WER: {:.3f}'.format(
            total_cer_dist/total_cer_len, total_wer_dist/total_wer_len))
def train():
    dataloader_train = LMDataLoader(args.dataset_train, 999999, args)
    # dataloader_train = PTBDataLoader(args.dataset_train, 80, args)

    tensor_global_step = tf.train.get_or_create_global_step()

    model = args.Model(tensor_global_step, is_train=True, args=args)
    model_infer = args.Model(tensor_global_step, is_train=False, args=args)
    input_pl = tf.placeholder(tf.int32, [None, None])
    len_pl = tf.placeholder(tf.int32, [None])
    score_T, distribution_T = model.score(input_pl, len_pl)

    size_variables()
    start_time = datetime.now()
    checker = check_to_stop()

    saver = tf.train.Saver(max_to_keep=15)
    summary = Summary(str(args.dir_log))

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        if args.dirs.checkpoint_init:
            checkpoint = tf.train.latest_checkpoint(args.dirs.checkpoint_init)
            saver.restore(sess, checkpoint)

        batch_time = time()
        num_processed = 0
        progress = 0
        for x, y, len_x, len_y in dataloader_train:
            global_step, lr = sess.run([tensor_global_step, model.learning_rate])
            feed_dict = {model.list_pl[0]: x,
                         model.list_pl[1]: y,
                         model.list_pl[2]: len_x,
                         model.list_pl[3]: len_y}
            loss, shape_batch, _ = sess.run(model.list_run, feed_dict=feed_dict)

            if global_step % 10 == 0:
                num_tokens = np.sum(len_x)
                # loss is the summed cross-entropy over the batch, so
                # perplexity is exp(total loss / number of tokens)
                ppl = np.exp(loss / num_tokens)
                summary.summary_scalar('loss', loss, global_step)
                summary.summary_scalar('ppl', ppl, global_step)
                summary.summary_scalar('lr', lr, global_step)

                num_processed += shape_batch[0]
                used_time = time() - batch_time
                batch_time = time()
                progress = num_processed / args.dataset_train.size_dataset

                logging.info('ppl: {:.3f}\tshape_batch: {} lr:{:.6f} time:{:.2f}s {:.3f}% step: {}'.format(
                    ppl, shape_batch, lr, used_time, progress * 100.0, global_step))

            if global_step % args.save_step == args.save_step - 1:
                saver.save(get_session(sess), str(args.dir_checkpoint / 'model'),
                           global_step=global_step, write_meta_graph=False)

            if global_step % args.dev_step == args.dev_step - 1:
                ppl_dev = dev(args.dataset_dev, model_infer, sess)
                summary.summary_scalar('ppl_dev', ppl_dev, global_step)
                # accuracy = dev_external('/mnt/lustre/xushuang/easton/projects/asr-tf/exp/aishell/lm_acc.txt',
                #                         model_infer, input_pl, len_pl, score_T, distribution_T, sess)
                # summary.summary_scalar('accuracy', accuracy, global_step)

    logging.info('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
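
# Illustrative sketch (not from the original repo): the perplexity computation logged in
# the LM train() above. It assumes `loss` is the summed token cross-entropy over the batch;
# the helper name `batch_perplexity` is made up for this sketch.
import numpy as np

def batch_perplexity(summed_loss, num_tokens):
    """ppl = exp(total cross-entropy / number of predicted tokens)."""
    return float(np.exp(summed_loss / max(num_tokens, 1)))

# e.g. batch_perplexity(loss, np.sum(len_x)) reproduces the ppl value logged every 10 steps.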