# Imports assumed by the training/inference routines below; the project-
# internal helpers (PinyinDataSet, Generator, Conditional_GAN, TFReader,
# TFDataReader, TFData, TextDataSet, ASRDataLoader, ASR_Multi_DataLoader,
# Summary, size_variables, get_session, dev, decode, decode_test,
# array2text, int2vector, get_batch_length, readTFRecord, read_ngram,
# ngram2kernel, args, ...) come from the repo's own modules.
import sys
import logging
from time import time
from datetime import datetime
from pathlib import Path

import numpy as np
import tensorflow as tf
import editdistance as ed  # assumed alias: ed.eval is the Levenshtein distance
from tqdm import tqdm


def train():
    dataset_train = PinyinDataSet(list_files=[args.dirs.text.data], args=args, _shuffle=True)
    tfdata_train = tf.data.Dataset.from_generator(
        dataset_train, (tf.int32, tf.int32),
        (tf.TensorShape([None]), tf.TensorShape([None])))
    iter_train = tfdata_train.cache().repeat().shuffle(10000).\
        padded_batch(args.text_batch_size, ([args.max_label_len], [args.max_label_len])).\
        prefetch(buffer_size=5).\
        make_one_shot_iterator().get_next()

    dataset_supervise = PinyinDataSet(list_files=[args.dirs.text.supervise], args=args, _shuffle=True)
    tfdata_supervise = tf.data.Dataset.from_generator(
        dataset_supervise, (tf.int32, tf.int32),
        (tf.TensorShape([None]), tf.TensorShape([None])))
    iter_supervise = tfdata_supervise.cache().repeat().shuffle(100).\
        padded_batch(args.num_supervised, ([args.max_label_len], [args.max_label_len])).\
        prefetch(buffer_size=5).\
        make_one_shot_iterator().get_next()

    dataset_dev = PinyinDataSet(list_files=[args.dirs.text.dev], args=args, _shuffle=False)
    tfdata_dev = tf.data.Dataset.from_generator(
        dataset_dev, (tf.int32, tf.int32),
        (tf.TensorShape([None]), tf.TensorShape([None])))
    tfdata_dev = tfdata_dev.cache().\
        padded_batch(args.text_batch_size, ([args.max_label_len], [args.max_label_len])).\
        prefetch(buffer_size=5).\
        make_initializable_iterator()
    iter_dev = tfdata_dev.get_next()

    tensor_global_step = tf.Variable(0, dtype=tf.int32, trainable=False)
    tensor_global_step0 = tf.Variable(0, dtype=tf.int32, trainable=False)
    tensor_global_step1 = tf.Variable(0, dtype=tf.int32, trainable=False)

    G = Generator(tensor_global_step, training=True, args=args)
    G_infer = Generator(tensor_global_step, training=False, args=args)
    vars_G = G.trainable_variables

    D = args.Model_D(tensor_global_step1, training=True, name='discriminator', args=args)
    gan = Conditional_GAN([tensor_global_step0, tensor_global_step1], G, D,
                          batch=None, unbatch=None, name='text_gan', args=args)

    size_variables()
    start_time = datetime.now()
    saver_G = tf.train.Saver(vars_G, max_to_keep=1)
    saver = tf.train.Saver(max_to_keep=15)
    summary = Summary(str(args.dir_log))

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        if args.dirs.checkpoint_G:
            saver_G.restore(sess, args.dirs.checkpoint_G)
            print('recover G from', args.dirs.checkpoint_G)

        # np.random.seed(0)
        pinyin_supervise, text_supervise = sess.run(iter_supervise)
        text_len_supervise = pinyin_len_supervise = get_batch_length(pinyin_supervise)
        feature_supervise, feature_len_supervise = int2vector(
            pinyin_supervise, pinyin_len_supervise,
            hidden_size=args.model.dim_input, uprate=args.uprate)
        feature_supervise += np.random.randn(*feature_supervise.shape) / args.noise

        # supervised warm-up of the generator
        for _ in range(500):
            sess.run(G.run_list,
                     feed_dict={G.list_pl[0]: feature_supervise,
                                G.list_pl[1]: feature_len_supervise,
                                G.list_pl[2]: text_supervise,
                                G.list_pl[3]: text_len_supervise})

        batch_time = time()
        global_step = 0
        while global_step < 99999999:
            # global_step, lr_G, lr_D = sess.run([tensor_global_step0, gan.learning_rate_G, gan.learning_rate_D])
            global_step, lr_G = sess.run([tensor_global_step0, G.learning_rate])

            pinyin_supervise, text_supervise = sess.run(iter_supervise)
            pinyin_len_supervise = get_batch_length(pinyin_supervise)
            feature_supervise, feature_len_supervise = int2vector(
                pinyin_supervise, pinyin_len_supervise,
                hidden_size=args.model.dim_input, uprate=args.uprate)
            feature_supervise += np.random.randn(*feature_supervise.shape) / args.noise

            # supervise
            # for _ in range(1):
            #     loss_G_supervise, _ = sess.run(G.run_list,
            #                                    feed_dict={G.list_pl[0]: feature_text_supervise,
            #                                               G.list_pl[1]: text_len_supervise,
            #                                               G.list_pl[2]: text_supervise,
            #                                               G.list_pl[3]: text_len_supervise})

            # generator input
            pinyin_G, text_G = sess.run(iter_train)
            pinyin_lens_G = get_batch_length(pinyin_G)
            feature_G, lens_G = int2vector(pinyin_G, pinyin_lens_G,
                                           hidden_size=args.model.dim_input, uprate=args.uprate)
            feature_G += np.random.randn(*feature_G.shape) / args.noise
            loss_G, loss_G_supervise, _ = sess.run(
                gan.list_train_G,
                feed_dict={gan.list_G_pl[0]: feature_G,
                           gan.list_G_pl[1]: lens_G,
                           gan.list_G_pl[2]: feature_supervise,
                           gan.list_G_pl[3]: feature_len_supervise,
                           gan.list_G_pl[4]: text_supervise,
                           gan.list_G_pl[5]: pinyin_len_supervise})
            # loss_G = loss_G_supervise = 0

            # discriminator input (currently disabled)
            # for _ in range(5):
            #     # np.random.seed(2)
            #     pinyin_G, text_G = sess.run(iter_train)
            #     pinyin_lens_G = get_batch_length(pinyin_G)
            #     feature_G, feature_lens_G = int2vector(pinyin_G, pinyin_lens_G, hidden_size=args.model.dim_input, uprate=args.uprate)
            #     feature_G += np.random.randn(*feature_G.shape) / args.noise
            #
            #     pinyin_D, text_D = sess.run(iter_train)
            #     text_lens_D = get_batch_length(text_D)
            #     shape_text = text_D.shape
            #     loss_D, loss_D_res, loss_D_text, loss_gp, _ = sess.run(gan.list_train_D,
            #                                                            feed_dict={gan.list_D_pl[0]: text_D,
            #                                                                       gan.list_D_pl[1]: text_lens_D,
            #                                                                       gan.list_G_pl[0]: feature_G,
            #                                                                       gan.list_G_pl[1]: feature_lens_G})
            loss_D_res = loss_D_text = loss_gp = 0
            shape_text = [0, 0, 0]
            # loss_D_res = -loss_G
            # loss_G = loss_G_supervise = 0.0
            # loss_D = loss_D_text = loss_gp = 0.0

            # train
            # if global_step % 5 == 0:
            #     for _ in range(2):
            #         loss_supervise, shape_batch, _, _ = sess.run(G.list_run)
            # loss_G_supervise = 0

            used_time = time() - batch_time
            batch_time = time()

            if global_step % 10 == 0:
                # print('loss_G: {:.2f} loss_G_supervise: {:.2f} loss_D_res: {:.2f} loss_D_text: {:.2f} step: {}'.format(
                #     loss_G, loss_G_supervise, loss_D_res, loss_D_text, global_step))
                print('loss_G_supervise: {:.2f} loss res|real|gp: {:.2f}|{:.2f}|{:.2f}\tbatch: {}\tlr:{:.1e} {:.2f}s step: {}'.format(
                    loss_G_supervise, loss_D_res, loss_D_text, loss_gp,
                    shape_text, lr_G, used_time, global_step))
                # summary.summary_scalar('loss_G', loss_G, global_step)
                # summary.summary_scalar('loss_D', loss_D, global_step)
                # summary.summary_scalar('lr_G', lr_G, global_step)
                # summary.summary_scalar('lr_D', lr_D, global_step)

            if global_step % args.save_step == args.save_step - 1:
                saver.save(get_session(sess), str(args.dir_checkpoint / 'model'),
                           global_step=global_step, write_meta_graph=True)
                print('saved model as', str(args.dir_checkpoint) + '/model-' + str(global_step))

            # if global_step % args.dev_step == args.dev_step - 1:
            if global_step % args.dev_step == 1:
                pinyin_G_dev, text_G_dev = dev(iter_dev, tfdata_dev, dataset_dev, sess, G_infer)

            if global_step % args.decode_step == args.decode_step - 1:
                # decode() reuses the most recent dev batch
                decode(pinyin_G_dev, text_G_dev, sess, G_infer)

    logging.info('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
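
# `get_batch_length` and `int2vector` are repo helpers used above; the
# sketches below are assumptions about their behavior, not the repo's actual
# code: lengths are counted up to the padding id, and token ids are mapped to
# dense "pseudo-acoustic" features of width hidden_size, upsampled along time
# by roughly `uprate`.
def get_batch_length_sketch(batch, pad=0):
    # batch: [batch, time] int array padded with `pad`
    return np.sum(batch != pad, axis=1).astype(np.int32)


def int2vector_sketch(batch, lens, hidden_size, uprate=1.0, vocab_size=10000):
    # fixed random embedding table (vocab_size is a hypothetical bound)
    rng = np.random.RandomState(0)
    table = rng.randn(vocab_size, hidden_size).astype(np.float32)
    up = max(int(round(uprate)), 1)
    feats = np.repeat(table[batch], up, axis=1)  # [batch, time * up, hidden_size]
    return feats, (lens * up).astype(np.int32)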
def train():
    print('reading data from', args.dirs.train.tfdata)
    dataReader_train = TFReader(args.dirs.train.tfdata, args=args)
    batch_train = dataReader_train.fentch_batch_bucket()
    dataReader_untrain = TFReader(args.dirs.untrain.tfdata, args=args)
    batch_untrain = dataReader_untrain.fentch_batch(args.batch_size)
    # batch_untrain = dataReader_untrain.fentch_batch_bucket()
    args.dirs.untrain.tfdata = Path(args.dirs.untrain.tfdata)
    args.data.untrain_size = TFData.read_tfdata_info(args.dirs.untrain.tfdata)['size_dataset']

    dataset_text = TextDataSet(list_files=[args.dirs.text.data], args=args, _shuffle=True)
    tfdata_train = tf.data.Dataset.from_generator(
        dataset_text, (tf.int32), (tf.TensorShape([None])))
    iter_text = tfdata_train.cache().repeat().shuffle(1000).\
        padded_batch(args.text_batch_size, ([args.max_label_len])).\
        prefetch(buffer_size=5).\
        make_one_shot_iterator().get_next()

    feat, label = readTFRecord(args.dirs.dev.tfdata, args, _shuffle=False, transform=True)
    dataloader_dev = ASRDataLoader(args.dataset_dev, args, feat, label,
                                   batch_size=args.batch_size, num_loops=1)

    tensor_global_step = tf.train.get_or_create_global_step()
    tensor_global_step0 = tf.Variable(0, dtype=tf.int32, trainable=False)
    tensor_global_step1 = tf.Variable(0, dtype=tf.int32, trainable=False)

    G = args.Model(tensor_global_step,
                   encoder=args.model.encoder.type,
                   decoder=args.model.decoder.type,
                   batch=batch_train, training=True, args=args)
    G_infer = args.Model(tensor_global_step,
                         encoder=args.model.encoder.type,
                         decoder=args.model.decoder.type,
                         training=False, args=args)
    vars_G = G.trainable_variables

    D = args.Model_D(tensor_global_step1, training=True, name='discriminator', args=args)
    gan = args.GAN([tensor_global_step0, tensor_global_step1], G, D,
                   batch=batch_train, unbatch=batch_untrain, name='GAN', args=args)

    size_variables()
    start_time = datetime.now()
    saver_G = tf.train.Saver(vars_G, max_to_keep=1)
    saver = tf.train.Saver(max_to_keep=15)
    summary = Summary(str(args.dir_log))

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        dataloader_dev.sess = sess
        if args.dirs.checkpoint_G:
            saver_G.restore(sess, args.dirs.checkpoint_G)

        # supervised warm-up with the CTC loss
        for i in range(2000):
            ctc_loss, *_ = sess.run(G.list_run)
            if i % 400 == 0:
                print('ctc_loss: {:.2f}'.format(ctc_loss))

        batch_time = time()
        num_processed = 0
        num_processed_unbatch = 0
        progress = 0
        while progress < args.num_epochs:
            global_step, lr_G, lr_D = sess.run([
                tensor_global_step0, gan.learning_rate_G, gan.learning_rate_D])

            # untrain
            text = sess.run(iter_text)
            text_lens = get_batch_length(text)
            shape_text = text.shape
            # loss_D, loss_D_res, loss_D_text, loss_gp, _ = sess.run(gan.list_train_D,
            #                                                        feed_dict={gan.list_pl[0]: text,
            #                                                                   gan.list_pl[1]: text_lens})
            loss_D = loss_D_res = loss_D_text = loss_gp = 0

            (loss_G, loss_G_supervise, _), (shape_feature, shape_unfeature) = \
                sess.run([gan.list_train_G, gan.list_feature_shape])

            num_processed += shape_feature[0]
            num_processed_unbatch += shape_unfeature[0]
            used_time = time() - batch_time
            batch_time = time()
            progress = num_processed / args.data.train_size
            progress_unbatch = num_processed_unbatch / args.data.untrain_size

            if global_step % 20 == 0:
                print('loss_supervise: {:.2f}, loss res|real|gp: {:.2f}|{:.2f}|{:.2f}\tbatch: {}|{}|{}\tlr:{:.1e}|{:.1e} {:.2f}s {:.1f}%|{:.1f}% step: {}'.format(
                    loss_G_supervise, loss_D_res, loss_D_text, loss_gp,
                    shape_feature, shape_unfeature, shape_text,
                    lr_G, lr_D, used_time,
                    progress * 100.0, progress_unbatch * 100.0, global_step))
                # summary.summary_scalar('loss_G', loss_G, global_step)
                # summary.summary_scalar('loss_D', loss_D, global_step)
                # summary.summary_scalar('lr_G', lr_G, global_step)
                # summary.summary_scalar('lr_D', lr_D, global_step)

            if global_step % args.save_step == args.save_step - 1:
                saver.save(get_session(sess), str(args.dir_checkpoint / 'model'),
                           global_step=global_step, write_meta_graph=True)

            # if global_step % args.dev_step == args.dev_step - 1:
            if global_step % args.dev_step == 0:
                cer, wer = dev(step=global_step, dataloader=dataloader_dev,
                               model=G_infer, sess=sess, unit=args.data.unit,
                               idx2token=args.idx2token, token2idx=args.token2idx)
                # summary.summary_scalar('dev_cer', cer, global_step)
                # summary.summary_scalar('dev_wer', wer, global_step)

            if global_step % args.decode_step == args.decode_step - 1:
                decode_test(step=global_step, sample=args.dataset_test[10],
                            model=G_infer, sess=sess, unit=args.data.unit,
                            idx2token=args.idx2token, token2idx=args.token2idx)

    logging.info('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
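
# `TFReader.fentch_batch_bucket` above buckets utterances by length before
# batching; a plausible tf.data sketch of that idea (TF >= 1.13), with the
# element structure (feat, feat_len, label, label_len) assumed:
def bucket_batch_sketch(dataset, boundaries, batch_sizes):
    # batch_sizes needs len(boundaries) + 1 entries, one per bucket
    return dataset.apply(tf.data.experimental.bucket_by_sequence_length(
        element_length_func=lambda feat, feat_len, label, label_len: feat_len,
        bucket_boundaries=boundaries,          # e.g. [100, 200, 400]
        bucket_batch_sizes=batch_sizes,
        padded_shapes=([None, None], [], [None], [])))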
def infer_lm():
    tensor_global_step = tf.train.get_or_create_global_step()
    dataset_dev = args.dataset_test if args.dataset_test else args.dataset_dev

    model_lm = args.Model_LM(tensor_global_step, training=False, args=args.args_lm)
    args.lm_obj = model_lm
    saver_lm = tf.train.Saver(model_lm.variables())

    args.top_scope = tf.get_variable_scope()  # top-level scope
    args.lm_scope = model_lm.decoder.scope

    model_infer = args.Model(tensor_global_step,
                             encoder=args.model.encoder.type,
                             decoder=args.model.decoder.type,
                             training=False, args=args)
    saver = tf.train.Saver(model_infer.variables())

    size_variables()

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        checkpoint = tf.train.latest_checkpoint(args.dirs.checkpoint_init)
        checkpoint_lm = tf.train.latest_checkpoint(args.dirs.lm_checkpoint)
        saver.restore(sess, checkpoint)
        saver_lm.restore(sess, checkpoint_lm)

        total_cer_dist = 0
        total_cer_len = 0
        total_wer_dist = 0
        total_wer_len = 0
        with open(args.dir_model.name + '_decode.txt', 'w') as fw:
            # with open('/mnt/lustre/xushuang/easton/projects/asr-tf/exp/aishell/lm_acc.txt', 'w') as fw:
            for sample in tqdm(dataset_dev):
                if not sample:
                    continue
                dict_feed = {model_infer.list_pl[0]: np.expand_dims(sample['feature'], axis=0),
                             model_infer.list_pl[1]: np.array([len(sample['feature'])])}
                sample_id, shape_batch, beam_decoded = sess.run(
                    model_infer.list_run, feed_dict=dict_feed)
                # decoded, sample_id, decoded_sparse = sess.run(model_infer.list_run, feed_dict=dict_feed)

                res_txt = array2text(sample_id[0], args.data.unit, args.idx2token,
                                     min_idx=0, max_idx=args.dim_output - 1)
                ref_txt = array2text(sample['label'], args.data.unit, args.idx2token,
                                     min_idx=0, max_idx=args.dim_output - 1)

                list_res_char = list(res_txt)
                list_ref_char = list(ref_txt)
                list_res_word = res_txt.split()
                list_ref_word = ref_txt.split()
                cer_dist = ed.eval(list_res_char, list_ref_char)
                cer_len = len(list_ref_char)
                wer_dist = ed.eval(list_res_word, list_ref_word)
                wer_len = len(list_ref_word)
                total_cer_dist += cer_dist
                total_cer_len += cer_len
                total_wer_dist += wer_dist
                total_wer_len += wer_len
                if cer_len == 0:
                    # avoid division by zero on empty references
                    cer_len = 1000
                    wer_len = 1000

                if wer_dist / wer_len > 0:
                    print('ref ', ref_txt)
                    for i, decoded, score, rerank_score in zip(
                            range(10), beam_decoded[0][0], beam_decoded[1][0], beam_decoded[2][0]):
                        candidate = array2text(decoded, args.data.unit, args.idx2token,
                                               min_idx=0, max_idx=args.dim_output - 1)
                        print('res', i, candidate, score, rerank_score)
                        fw.write('res: {}; ref: {}\n'.format(candidate, ref_txt))

                fw.write('id:\t{} \nres:\t{}\nref:\t{}\n\n'.format(
                    sample['id'], res_txt, ref_txt))
                logging.info('current cer: {:.3f}, wer: {:.3f};\tall cer {:.3f}, wer: {:.3f}'.format(
                    cer_dist / cer_len, wer_dist / wer_len,
                    total_cer_dist / total_cer_len, total_wer_dist / total_wer_len))

    logging.info('dev CER {:.3f}: WER: {:.3f}'.format(
        total_cer_dist / total_cer_len, total_wer_dist / total_wer_len))
def infer():
    tensor_global_step = tf.train.get_or_create_global_step()

    model_infer = args.Model(tensor_global_step,
                             encoder=args.model.encoder.type,
                             decoder=args.model.decoder.type,
                             training=False, args=args)
    dataset_dev = args.dataset_test if args.dataset_test else args.dataset_dev
    saver = tf.train.Saver(max_to_keep=40)

    size_variables()

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        checkpoint = tf.train.latest_checkpoint(args.dirs.checkpoint_init)
        saver.restore(sess, checkpoint)

        total_cer_dist = 0
        total_cer_len = 0
        total_wer_dist = 0
        total_wer_len = 0
        with open(args.dir_model.name + '_decode.txt', 'w') as fw:
            for i, sample in enumerate(dataset_dev):
                if not sample:
                    continue
                dict_feed = {model_infer.list_pl[0]: np.expand_dims(sample['feature'], axis=0),
                             model_infer.list_pl[1]: np.array([len(sample['feature'])])}
                sample_id, shape_batch, _ = sess.run(model_infer.list_run, feed_dict=dict_feed)
                # decoded, sample_id, decoded_sparse = sess.run(model_infer.list_run, feed_dict=dict_feed)

                res_txt = array2text(sample_id[0], args.data.unit, args.idx2token, args.token2idx)
                ref_txt = array2text(sample['label'], args.data.unit, args.idx2token, args.token2idx)

                list_res_char = list(res_txt)
                list_ref_char = list(ref_txt)
                list_res_word = res_txt.split()
                list_ref_word = ref_txt.split()
                cer_dist = ed.eval(list_res_char, list_ref_char)
                cer_len = len(list_ref_char)
                wer_dist = ed.eval(list_res_word, list_ref_word)
                wer_len = len(list_ref_word)
                total_cer_dist += cer_dist
                total_cer_len += cer_len
                total_wer_dist += wer_dist
                total_wer_len += wer_len
                if cer_len == 0:
                    # avoid division by zero on empty references
                    cer_len = 1000
                    wer_len = 1000

                if wer_dist / wer_len > 0:
                    fw.write('id:\t{} \nres:\t{}\nref:\t{}\n\n'.format(
                        sample['id'], res_txt, ref_txt))
                sys.stdout.write('\rcurrent cer: {:.3f}, wer: {:.3f};\tall cer {:.3f}, wer: {:.3f} {}/{} {:.2f}%'.format(
                    cer_dist / cer_len, wer_dist / wer_len,
                    total_cer_dist / total_cer_len, total_wer_dist / total_wer_len,
                    i, len(dataset_dev), i / len(dataset_dev) * 100))
                sys.stdout.flush()

    logging.info('dev CER {:.3f}: WER: {:.3f}'.format(
        total_cer_dist / total_cer_len, total_wer_dist / total_wer_len))
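
# The CER/WER bookkeeping above, isolated for reference. `ed` is assumed to
# be the `editdistance` package, whose eval() returns the Levenshtein
# distance between two sequences; here the zero-length guard is a simple
# max() instead of the sentinel 1000 used above.
def error_rates_sketch(res_txt, ref_txt):
    cer = ed.eval(list(res_txt), list(ref_txt)) / max(len(ref_txt), 1)
    wer = ed.eval(res_txt.split(), ref_txt.split()) / max(len(ref_txt.split()), 1)
    return cer, wer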
def train():
    print('reading data from', args.dirs.train.tfdata)
    dataReader_train = TFDataReader(args.dirs.train.tfdata, args=args, _shuffle=True, transform=True)
    dataReader_dev = TFDataReader(args.dirs.dev.tfdata, args=args, _shuffle=False, transform=True)
    batch_train = dataReader_train.fentch_batch_bucket()
    dataloader_dev = ASRDataLoader(args.dataset_dev, args,
                                   dataReader_dev.feat, dataReader_dev.label,
                                   batch_size=args.batch_size, num_loops=1)

    tensor_global_step = tf.train.get_or_create_global_step()

    model = args.Model(tensor_global_step,
                       encoder=args.model.encoder.type,
                       decoder=args.model.decoder.type,
                       batch=batch_train, training=True, args=args)
    model_infer = args.Model(tensor_global_step,
                             encoder=args.model.encoder.type,
                             decoder=args.model.decoder.type,
                             training=False, args=args)

    size_variables()
    start_time = datetime.now()
    saver = tf.train.Saver(max_to_keep=15)

    if args.dirs.lm_checkpoint:
        from tfTools.checkpointTools import list_variables
        list_lm_vars_pretrained = list_variables(args.dirs.lm_checkpoint)
        list_lm_vars = model.decoder.lm.variables
        dict_lm_vars = {}
        for var in list_lm_vars:
            if 'embedding' in var.name:
                for var_pre in list_lm_vars_pretrained:
                    if 'embedding' in var_pre[0]:
                        break
            else:
                name = var.name.split(model.decoder.lm.name)[1].split(':')[0]
                for var_pre in list_lm_vars_pretrained:
                    if name in var_pre[0]:
                        break
            # {'var_name_in_checkpoint': var_in_graph}
            dict_lm_vars[var_pre[0]] = var
        saver_lm = tf.train.Saver(dict_lm_vars)

    summary = Summary(str(args.dir_log))

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        # for i in range(100):
        #     batch = sess.run(batch_train)
        #     import pdb; pdb.set_trace()
        #     print(batch[0].shape)
        _, labels, _, len_labels = sess.run(batch_train)
        if args.dirs.checkpoint:
            saver.restore(sess, args.dirs.checkpoint)
        elif args.dirs.lm_checkpoint:
            lm_checkpoint = tf.train.latest_checkpoint(args.dirs.lm_checkpoint)
            saver_lm.restore(sess, lm_checkpoint)
        dataloader_dev.sess = sess

        batch_time = time()
        num_processed = 0
        progress = 0
        while progress < args.num_epochs:
            global_step, lr = sess.run([tensor_global_step, model.learning_rate])
            loss, shape_batch, _, debug = sess.run(model.list_run)
            num_processed += shape_batch[0]
            used_time = time() - batch_time
            batch_time = time()
            progress = num_processed / args.data.train_size

            if global_step % 50 == 0:
                print('loss: {:.3f}\tbatch: {} lr:{:.6f} time:{:.2f}s {:.3f}% step: {}'.format(
                    loss, shape_batch, lr, used_time, progress * 100.0, global_step))
                summary.summary_scalar('loss', loss, global_step)
                summary.summary_scalar('lr', lr, global_step)

            if global_step % args.save_step == args.save_step - 1:
                saver.save(get_session(sess), str(args.dir_checkpoint / 'model'),
                           global_step=global_step, write_meta_graph=True)
                print('saved model in', str(args.dir_checkpoint) + '/model-' + str(global_step))

            if global_step % args.dev_step == args.dev_step - 1:
                cer, wer = dev(step=global_step, dataloader=dataloader_dev,
                               model=model_infer, sess=sess, unit=args.data.unit,
                               token2idx=args.token2idx, idx2token=args.idx2token)
                summary.summary_scalar('dev_cer', cer, global_step)
                summary.summary_scalar('dev_wer', wer, global_step)

            if global_step % args.decode_step == args.decode_step - 1:
                decode_test(step=global_step,
                            sample=args.dataset_test.uttid2sample(args.sample_uttid),
                            model=model_infer, sess=sess, unit=args.data.unit,
                            idx2token=args.idx2token, token2idx=args.token2idx)

    logging.info('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
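
# `tfTools.checkpointTools.list_variables` above returns (name, shape) pairs
# from a checkpoint; it most likely wraps TensorFlow's own inspection API,
# sketched here under that assumption:
def list_variables_sketch(checkpoint_dir_or_file):
    ckpt = tf.train.latest_checkpoint(checkpoint_dir_or_file) or checkpoint_dir_or_file
    return tf.train.list_variables(ckpt)  # [(var_name, shape), ...]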
def train():
    print('reading data from', args.dirs.train.tfdata)
    dataReader_train = TFDataReader(args.dirs.train.tfdata, args=args)
    batch_train = dataReader_train.fentch_multi_batch_bucket()

    dataReader_dev = TFDataReader(args.dirs.dev.tfdata, args=args, _shuffle=False, transform=True)
    dataloader_dev = ASR_Multi_DataLoader(args.dataset_dev, args,
                                          dataReader_dev.feat,
                                          dataReader_dev.phone,
                                          dataReader_dev.label,
                                          batch_size=args.batch_size, num_loops=1)

    tensor_global_step = tf.train.get_or_create_global_step()

    G = args.Model(tensor_global_step,
                   encoder=args.model.encoder.type,
                   encoder2=args.model.encoder2.type,
                   decoder=args.model.decoder.type,
                   batch=batch_train, training=True, args=args)
    G_infer = args.Model(tensor_global_step,
                         encoder=args.model.encoder.type,
                         encoder2=args.model.encoder2.type,
                         decoder=args.model.decoder.type,
                         training=False, args=args)
    vars_ASR = G.trainable_variables()
    vars_spiker = G.trainable_variables(G.name + '/spiker')

    size_variables()
    start_time = datetime.now()
    saver_ASR = tf.train.Saver(vars_ASR, max_to_keep=30)
    saver_S = tf.train.Saver(vars_spiker, max_to_keep=30)
    saver = tf.train.Saver(max_to_keep=15)
    summary = Summary(str(args.dir_log))
    step_bias = 0

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        dataloader_dev.sess = sess
        if args.dirs.checkpoint_G:
            saver_ASR.restore(sess, args.dirs.checkpoint_G)
            # resume the step count from the checkpoint filename suffix
            step_bias = int(args.dirs.checkpoint_G.split('-')[-1])
        if args.dirs.checkpoint_S:
            saver_S.restore(sess, args.dirs.checkpoint_S)

        batch_time = time()
        num_processed = 0
        progress = 0
        while progress < args.num_epochs:
            # supervised training
            global_step, lr = sess.run([tensor_global_step, G.learning_rate])
            global_step += step_bias
            loss_G, shape_batch, _, (ctc_loss, ce_loss, *_) = sess.run(G.list_run)

            num_processed += shape_batch[0]
            used_time = time() - batch_time
            batch_time = time()
            progress = num_processed / args.data.train_size

            if global_step % 40 == 0:
                print('ctc_loss: {:.2f}, ce_loss: {:.2f} batch: {} lr:{:.1e} {:.2f}s {:.3f}% step: {}'.format(
                    np.mean(ctc_loss), np.mean(ce_loss), shape_batch, lr,
                    used_time, progress * 100, global_step))

            if global_step % args.save_step == args.save_step - 1:
                saver_ASR.save(get_session(sess), str(args.dir_checkpoint / 'model'),
                               global_step=global_step)
                print('saved ASR model in',
                      str(args.dir_checkpoint) + '/model-' + str(global_step))
                saver_S.save(get_session(sess), str(args.dir_checkpoint / 'model_S'),
                             global_step=global_step)
                print('saved Spiker model in',
                      str(args.dir_checkpoint) + '/model_S-' + str(global_step))

            if global_step % args.dev_step == args.dev_step - 1:
                # if global_step % args.dev_step == 0:
                per, cer, wer = dev(step=global_step, dataloader=dataloader_dev,
                                    model=G_infer, sess=sess,
                                    unit=args.data.unit, args=args)
                summary.summary_scalar('dev_per', per, global_step)
                summary.summary_scalar('dev_cer', cer, global_step)
                summary.summary_scalar('dev_wer', wer, global_step)

            if global_step % args.decode_step == args.decode_step - 1:
                # if True:
                decode_test(step=global_step,
                            sample=args.dataset_test.uttid2sample(args.sample_uttid),
                            model=G_infer, sess=sess,
                            unit=args.data.unit, args=args)

    logging.info('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
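
# The per-component savers above select disjoint variable subsets by scope;
# a minimal TF1 sketch of how such a selection is typically done (the repo's
# G.trainable_variables(scope) likely does something equivalent):
def vars_in_scope_sketch(scope):
    return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
# e.g. tf.train.Saver(vars_in_scope_sketch('model/spiker'), max_to_keep=30)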
def train_gan():
    print('reading data from', args.dirs.train.tfdata)
    dataReader_train = TFDataReader(args.dirs.train.tfdata, args=args)
    batch_train = dataReader_train.fentch_multi_batch_bucket()
    dataReader_untrain = TFDataReader(args.dirs.untrain.tfdata, args=args)
    batch_untrain = dataReader_untrain.fentch_multi_batch(args.batch_size)
    args.dirs.untrain.tfdata = Path(args.dirs.untrain.tfdata)
    args.data.untrain_size = TFDataReader.read_tfdata_info(args.dirs.untrain.tfdata)['size_dataset']

    dataset_text = TextDataSet(list_files=[args.dirs.text.data], args=args, _shuffle=True)
    tfdata_train = tf.data.Dataset.from_generator(
        dataset_text, (tf.int32), (tf.TensorShape([None])))
    iter_text = tfdata_train.cache().repeat().shuffle(1000).\
        padded_batch(args.text_batch_size, ([args.max_label_len])).\
        prefetch(buffer_size=5).\
        make_one_shot_iterator().get_next()

    dataReader_dev = TFDataReader(args.dirs.dev.tfdata, args=args, _shuffle=False, transform=True)
    dataloader_dev = ASR_Multi_DataLoader(args.dataset_dev, args,
                                          dataReader_dev.feat,
                                          dataReader_dev.phone,
                                          dataReader_dev.label,
                                          batch_size=args.batch_size, num_loops=1)

    tensor_global_step = tf.train.get_or_create_global_step()
    tensor_global_step0 = tf.Variable(0, dtype=tf.int32, trainable=False)
    tensor_global_step1 = tf.Variable(0, dtype=tf.int32, trainable=False)

    G = args.Model(tensor_global_step,
                   encoder=args.model.encoder.type,
                   encoder2=args.model.encoder2.type,
                   decoder=args.model.decoder.type,
                   batch=batch_train, training=True, args=args)
    G_infer = args.Model(tensor_global_step,
                         encoder=args.model.encoder.type,
                         encoder2=args.model.encoder2.type,
                         decoder=args.model.decoder.type,
                         training=False, args=args)
    vars_ASR = G.trainable_variables()
    # vars_G_ocd = G.trainable_variables('Ectc_Docd/ocd_decoder')

    D = args.Model_D(tensor_global_step1, training=True, name='discriminator', args=args)
    gan = args.GAN([tensor_global_step0, tensor_global_step1], G, D,
                   batch=batch_train, unbatch=batch_untrain, name='GAN', args=args)

    size_variables()
    start_time = datetime.now()
    saver_ASR = tf.train.Saver(vars_ASR, max_to_keep=10)
    saver = tf.train.Saver(max_to_keep=15)
    summary = Summary(str(args.dir_log))

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        dataloader_dev.sess = sess
        if args.dirs.checkpoint_G:
            saver_ASR.restore(sess, args.dirs.checkpoint_G)

        batch_time = time()
        num_processed = 0
        num_processed_unbatch = 0
        progress = 0
        while progress < args.num_epochs:
            # semi-supervised training
            global_step, lr_G, lr_D = sess.run([
                tensor_global_step0, gan.learning_rate_G, gan.learning_rate_D])

            # several discriminator updates per generator update
            for _ in range(3):
                text = sess.run(iter_text)
                text_lens = get_batch_length(text)
                shape_text = text.shape
                loss_D, loss_D_res, loss_D_text, loss_gp, _ = sess.run(
                    gan.list_train_D,
                    feed_dict={gan.list_pl[0]: text,
                               gan.list_pl[1]: text_lens})
                # loss_D = loss_D_res = loss_D_text = loss_gp = 0

            (loss_G, ctc_loss, ce_loss, _), (shape_batch, shape_unbatch) = \
                sess.run([gan.list_train_G, gan.list_feature_shape])

            num_processed += shape_batch[0]
            # num_processed_unbatch += shape_unbatch[0]
            used_time = time() - batch_time
            batch_time = time()
            progress = num_processed / args.data.train_size
            progress_unbatch = num_processed_unbatch / args.data.untrain_size

            if global_step % 40 == 0:
                print('ctc|ce loss: {:.2f}|{:.2f}, loss res|real|gp: {:.2f}|{:.2f}|{:.2f}\t{}|{}\tlr:{:.1e}|{:.1e} {:.2f}s {:.3f}% step: {}'.format(
                    np.mean(ctc_loss), np.mean(ce_loss), loss_D_res, loss_D_text, loss_gp,
                    shape_batch, shape_unbatch, lr_G, lr_D, used_time,
                    progress * 100, global_step))
                summary.summary_scalar('ctc_loss', np.mean(ctc_loss), global_step)
                summary.summary_scalar('ce_loss', np.mean(ce_loss), global_step)

            if global_step % args.save_step == args.save_step - 1:
                saver_ASR.save(get_session(sess), str(args.dir_checkpoint / 'model_G'),
                               global_step=global_step, write_meta_graph=True)
                print('saved G in', str(args.dir_checkpoint) + '/model_G-' + str(global_step))
                # saver_G_en.save(get_session(sess), str(args.dir_checkpoint / 'model_G_en'), global_step=global_step, write_meta_graph=True)
                # print('saved model in', str(args.dir_checkpoint) + '/model_G_en-' + str(global_step))

            if global_step % args.dev_step == args.dev_step - 1:
                # if True:
                per, cer, wer = dev(step=global_step, dataloader=dataloader_dev,
                                    model=G_infer, sess=sess,
                                    unit=args.data.unit, args=args)
                summary.summary_scalar('dev_per', per, global_step)
                summary.summary_scalar('dev_cer', cer, global_step)
                summary.summary_scalar('dev_wer', wer, global_step)

            if global_step % args.decode_step == args.decode_step - 1:
                # if True:
                decode_test(step=global_step,
                            sample=args.dataset_test.uttid2sample(args.sample_uttid),
                            model=G_infer, sess=sess,
                            unit=args.data.unit, args=args)

    logging.info('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
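
# `loss_gp` in the discriminator updates above is a gradient penalty in the
# WGAN-GP style; a minimal TF1 sketch of that term. The callable D and equal
# [batch, time, dim] padding of real/fake are assumptions:
def gradient_penalty_sketch(D, real, fake, batch_size):
    eps = tf.random_uniform([batch_size, 1, 1], 0.0, 1.0)
    interp = eps * real + (1.0 - eps) * fake      # random points between real and fake
    grads = tf.gradients(D(interp), [interp])[0]  # dD/dx at those points
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2]) + 1e-10)
    return tf.reduce_mean(tf.square(norm - 1.0))  # penalize ||grad|| far from 1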
def train():
    args.num_gpus = len(args.gpus.split(',')) - 1
    args.list_gpus = ['/gpu:{}'.format(i) for i in range(args.num_gpus)]

    # bucket
    if args.bucket_boundaries:
        args.list_bucket_boundaries = [int(i) for i in args.bucket_boundaries.split(',')]
        assert args.num_batch_tokens
        args.list_batch_size = ([int(args.num_batch_tokens / boundary) * args.num_gpus
                                 for boundary in args.list_bucket_boundaries] + [args.num_gpus])
        args.list_infer_batch_size = ([int(args.num_batch_tokens / boundary)
                                       for boundary in args.list_bucket_boundaries] + [1])
        args.batch_size *= args.num_gpus
        logging.info('\nbucket_boundaries: {} \nbatch_size: {}'.format(
            args.list_bucket_boundaries, args.list_batch_size))

    print('reading data from', args.dirs.train.tfdata)
    dataReader_train = TFReader(args.dirs.train.tfdata, args=args)
    batch_train = dataReader_train.fentch_batch_bucket()
    dataReader_untrain = TFReader(args.dirs.untrain.tfdata, args=args)
    # batch_untrain = dataReader_untrain.fentch_batch(args.batch_size)
    batch_untrain = dataReader_untrain.fentch_batch_bucket()
    args.dirs.untrain.tfdata = Path(args.dirs.untrain.tfdata)
    args.data.untrain_size = TFData.read_tfdata_info(args.dirs.untrain.tfdata)['size_dataset']

    feat, label = readTFRecord(args.dirs.dev.tfdata, args, _shuffle=False, transform=True)
    dataloader_dev = ASRDataLoader(args.dataset_dev, args, feat, label,
                                   batch_size=args.batch_size, num_loops=1)

    tensor_global_step = tf.train.get_or_create_global_step()

    # get dataset ngram statistics
    ngram_py, total_num = read_ngram(args.EODM.top_k, args.dirs.text.ngram,
                                     args.token2idx, type='list')
    kernel, py = ngram2kernel(ngram_py, args.EODM.ngram, args.EODM.top_k, args.dim_output)

    G = args.Model(tensor_global_step,
                   encoder=args.model.encoder.type,
                   decoder=args.model.decoder.type,
                   kernel=kernel, py=py,
                   batch=batch_train, unbatch=batch_untrain,
                   training=True, args=args)
    # place inference on the remaining GPU
    args.list_gpus = ['/gpu:{}'.format(args.num_gpus)]
    G_infer = args.Model(tensor_global_step,
                         encoder=args.model.encoder.type,
                         decoder=args.model.decoder.type,
                         training=False, args=args)
    vars_G = G.variables()

    size_variables()
    start_time = datetime.now()
    saver_G = tf.train.Saver(vars_G, max_to_keep=1)
    saver = tf.train.Saver(max_to_keep=15)
    summary = Summary(str(args.dir_log))

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        dataloader_dev.sess = sess
        if args.dirs.checkpoint_G:
            saver_G.restore(sess, args.dirs.checkpoint_G)

        batch_time = time()
        num_processed = 0
        num_processed_unbatch = 0
        progress = 0
        progress_unbatch = 0
        loss_CTC = 0.0
        shape_batch = [0, 0, 0]
        loss_EODM = 0.0
        shape_unbatch = [0, 0, 0]
        while progress < args.num_epochs:
            global_step, lr = sess.run([tensor_global_step, G.learning_rate])
            # alternate CTC (supervised) and EODM (unsupervised) updates
            if global_step % 2 == 0:
                loss_CTC, shape_batch, _ = sess.run(G.list_run)
                # loss_CTC = 0.0; shape_batch = [0, 0, 0]
            else:
                loss_EODM, shape_unbatch, _ = sess.run(G.list_run_EODM)
                # loss_EODM = 0.0; shape_unbatch = [0, 0, 0]
            num_processed += shape_batch[0]
            num_processed_unbatch += shape_unbatch[0]
            used_time = time() - batch_time
            batch_time = time()
            progress = num_processed / args.data.train_size
            progress_unbatch = num_processed_unbatch / args.data.untrain_size

            if global_step % 50 == 0:
                logging.info('loss: {:.2f}|{:.2f}\tbatch: {}|{} lr:{:.6f} time:{:.2f}s {:.2f}% {:.2f}% step: {}'.format(
                    loss_CTC, loss_EODM, shape_batch, shape_unbatch, lr, used_time,
                    progress * 100.0, progress_unbatch * 100.0, global_step))
                # summary.summary_scalar('loss_G', loss_G, global_step)
                # summary.summary_scalar('loss_D', loss_D, global_step)
                # summary.summary_scalar('lr_G', lr_G, global_step)
                # summary.summary_scalar('lr_D', lr_D, global_step)

            if global_step % args.save_step == args.save_step - 1:
                saver.save(get_session(sess), str(args.dir_checkpoint / 'model'),
                           global_step=global_step, write_meta_graph=True)

            if global_step % args.dev_step == args.dev_step - 1:
                # if global_step % args.dev_step == 0:
                cer, wer = dev(step=global_step, dataloader=dataloader_dev,
                               model=G_infer, sess=sess, unit=args.data.unit,
                               idx2token=args.idx2token, token2idx=args.token2idx)
                # summary.summary_scalar('dev_cer', cer, global_step)
                # summary.summary_scalar('dev_wer', wer, global_step)

            if global_step % args.decode_step == args.decode_step - 1:
                decode_test(step=global_step, sample=args.dataset_test[10],
                            model=G_infer, sess=sess, unit=args.data.unit,
                            idx2token=args.idx2token, token2idx=args.token2idx)

    logging.info('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
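
# `ngram2kernel` above turns the top-k n-grams into a fixed 1-D conv kernel
# so frame-wise token distributions can be matched against corpus n-gram
# statistics (EODM). A sketch under assumed shapes; the input format
# [((id_1, ..., id_n), prob), ...] is an assumption:
def ngram2kernel_sketch(ngrams, n, k, dim_output):
    kernel = np.zeros([n, dim_output, k], dtype=np.float32)
    py = np.zeros([k], dtype=np.float32)
    for i, (ids, prob) in enumerate(ngrams):
        for pos, token_id in enumerate(ids):
            kernel[pos, token_id, i] = 1.0  # one-hot at each n-gram position
        py[i] = prob  # corpus probability of the i-th n-gram
    return kernel, py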