def train():
    """Run the supervised ASR training loop.

    Builds a training graph (fed by a bucketed batch pipeline) and a
    weight-sharing inference graph, optionally warm-starts the decoder's
    language model from a pretrained LM checkpoint, then trains until
    `args.num_epochs` epochs of data have been consumed. Periodically logs
    progress, writes summaries, saves checkpoints, evaluates on the dev set
    and decodes a fixed sample utterance.

    All configuration comes from the module-level `args` object.
    """
    print('reading data from ', args.dirs.train.tfdata)  # typo fix: was 'form'
    dataReader_train = TFDataReader(args.dirs.train.tfdata, args=args,
                                    _shuffle=True, transform=True)
    dataReader_dev = TFDataReader(args.dirs.dev.tfdata, args=args,
                                  _shuffle=False, transform=True)
    batch_train = dataReader_train.fentch_batch_bucket()
    dataloader_dev = ASRDataLoader(args.dataset_dev, args,
                                   dataReader_dev.feat, dataReader_dev.label,
                                   batch_size=args.batch_size, num_loops=1)

    tensor_global_step = tf.train.get_or_create_global_step()

    # Training graph consumes the bucketed batch; the inference graph is
    # built without a batch and shares variables with the training graph.
    model = args.Model(tensor_global_step,
                       encoder=args.model.encoder.type,
                       decoder=args.model.decoder.type,
                       batch=batch_train,
                       training=True,
                       args=args)
    model_infer = args.Model(tensor_global_step,
                             encoder=args.model.encoder.type,
                             decoder=args.model.decoder.type,
                             training=False,
                             args=args)
    size_variables()
    start_time = datetime.now()
    saver = tf.train.Saver(max_to_keep=15)

    if args.dirs.lm_checkpoint:
        # Map variables stored in the pretrained LM checkpoint onto the
        # decoder's LM variables, matching by name substring.
        from tfTools.checkpointTools import list_variables
        list_lm_vars_pretrained = list_variables(args.dirs.lm_checkpoint)
        list_lm_vars = model.decoder.lm.variables
        dict_lm_vars = {}
        for var in list_lm_vars:
            if 'embedding' in var.name:
                for var_pre in list_lm_vars_pretrained:
                    if 'embedding' in var_pre[0]:
                        break
            else:
                name = var.name.split(model.decoder.lm.name)[1].split(':')[0]
                for var_pre in list_lm_vars_pretrained:
                    if name in var_pre[0]:
                        break
            # NOTE(review): if no checkpoint entry matches, `var_pre` keeps
            # the last iterated entry and the mapping is silently wrong —
            # confirm every graph variable has a checkpoint counterpart.
            # 'var_name_in_checkpoint': var_in_graph
            dict_lm_vars[var_pre[0]] = var
        saver_lm = tf.train.Saver(dict_lm_vars)

    summary = Summary(str(args.dir_log))

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        # Prime the input pipeline once; the fetched values themselves are
        # unused (the original unpacked them into never-read locals).
        sess.run(batch_train)
        if args.dirs.checkpoint:
            saver.restore(sess, args.dirs.checkpoint)
        elif args.dirs.lm_checkpoint:
            lm_checkpoint = tf.train.latest_checkpoint(args.dirs.lm_checkpoint)
            saver_lm.restore(sess, lm_checkpoint)

        dataloader_dev.sess = sess

        batch_time = time()
        num_processed = 0
        progress = 0
        while progress < args.num_epochs:
            global_step, lr = sess.run(
                [tensor_global_step, model.learning_rate])
            loss, shape_batch, _, debug = sess.run(model.list_run)

            num_processed += shape_batch[0]
            used_time = time() - batch_time
            batch_time = time()
            # `progress` counts epochs: utterances processed / dataset size.
            progress = num_processed / args.data.train_size

            if global_step % 50 == 0:
                print(
                    'loss: {:.3f}\tbatch: {} lr:{:.6f} time:{:.2f}s {:.3f}% step: {}'
                    .format(loss, shape_batch, lr, used_time,
                            progress * 100.0, global_step))
                summary.summary_scalar('loss', loss, global_step)
                summary.summary_scalar('lr', lr, global_step)

            if global_step % args.save_step == args.save_step - 1:
                saver.save(get_session(sess),
                           str(args.dir_checkpoint / 'model'),
                           global_step=global_step,
                           write_meta_graph=True)
                print('saved model in',
                      str(args.dir_checkpoint) + '/model-' + str(global_step))

            if global_step % args.dev_step == args.dev_step - 1:
                cer, wer = dev(step=global_step,
                               dataloader=dataloader_dev,
                               model=model_infer,
                               sess=sess,
                               unit=args.data.unit,
                               token2idx=args.token2idx,
                               idx2token=args.idx2token)
                summary.summary_scalar('dev_cer', cer, global_step)
                summary.summary_scalar('dev_wer', wer, global_step)

            if global_step % args.decode_step == args.decode_step - 1:
                decode_test(step=global_step,
                            sample=args.dataset_test.uttid2sample(
                                args.sample_uttid),
                            model=model_infer,
                            sess=sess,
                            unit=args.data.unit,
                            idx2token=args.idx2token,
                            token2idx=args.token2idx)

    logging.info('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
def train():
    """Run the supervised multi-task (ASR + spiker) training loop.

    Builds a training graph `G` fed by a multi-batch bucketed pipeline and a
    weight-sharing inference graph `G_infer`. ASR and spiker variables are
    saved/restored by separate savers; when a `checkpoint_G` is restored the
    step counter is biased by the step encoded in the checkpoint filename.

    All configuration comes from the module-level `args` object.

    NOTE(review): if this def lives in the same module as the earlier
    `train`, it silently shadows it — confirm that is intended.
    """
    print('reading data from ', args.dirs.train.tfdata)  # typo fix: was 'form'
    dataReader_train = TFDataReader(args.dirs.train.tfdata, args=args)
    batch_train = dataReader_train.fentch_multi_batch_bucket()
    dataReader_dev = TFDataReader(args.dirs.dev.tfdata, args=args,
                                  _shuffle=False, transform=True)
    dataloader_dev = ASR_Multi_DataLoader(args.dataset_dev, args,
                                          dataReader_dev.feat,
                                          dataReader_dev.phone,
                                          dataReader_dev.label,
                                          batch_size=args.batch_size,
                                          num_loops=1)

    tensor_global_step = tf.train.get_or_create_global_step()

    G = args.Model(tensor_global_step,
                   encoder=args.model.encoder.type,
                   encoder2=args.model.encoder2.type,
                   decoder=args.model.decoder.type,
                   batch=batch_train,
                   training=True,
                   args=args)
    G_infer = args.Model(tensor_global_step,
                         encoder=args.model.encoder.type,
                         encoder2=args.model.encoder2.type,
                         decoder=args.model.decoder.type,
                         training=False,
                         args=args)
    vars_ASR = G.trainable_variables()
    # Spiker sub-network variables live under the '<model name>/spiker' scope.
    vars_spiker = G.trainable_variables(G.name + '/spiker')

    size_variables()
    start_time = datetime.now()
    saver_ASR = tf.train.Saver(vars_ASR, max_to_keep=30)
    saver_S = tf.train.Saver(vars_spiker, max_to_keep=30)
    # NOTE(review): this all-variables saver is never used below — confirm
    # it can be dropped (kept here since Saver creation touches the graph).
    saver = tf.train.Saver(max_to_keep=15)
    summary = Summary(str(args.dir_log))
    step_bias = 0

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        dataloader_dev.sess = sess
        if args.dirs.checkpoint_G:
            saver_ASR.restore(sess, args.dirs.checkpoint_G)
            # Checkpoint paths end in '-<global_step>'; resume the step count.
            step_bias = int(args.dirs.checkpoint_G.split('-')[-1])
        if args.dirs.checkpoint_S:
            saver_S.restore(sess, args.dirs.checkpoint_S)

        batch_time = time()
        num_processed = 0
        progress = 0
        while progress < args.num_epochs:
            # supervised training
            global_step, lr = sess.run([tensor_global_step, G.learning_rate])
            global_step += step_bias
            loss_G, shape_batch, _, (ctc_loss, ce_loss, *_) = sess.run(G.list_run)

            num_processed += shape_batch[0]
            used_time = time() - batch_time
            batch_time = time()
            # `progress` counts epochs: utterances processed / dataset size.
            progress = num_processed / args.data.train_size

            if global_step % 40 == 0:
                print(
                    'ctc_loss: {:.2f}, ce_loss: {:.2f} batch: {} lr:{:.1e} {:.2f}s {:.3f}% step: {}'
                    .format(np.mean(ctc_loss), np.mean(ce_loss), shape_batch,
                            lr, used_time, progress * 100, global_step))

            if global_step % args.save_step == args.save_step - 1:
                saver_ASR.save(get_session(sess),
                               str(args.dir_checkpoint / 'model'),
                               global_step=global_step)
                print('saved ASR model in',
                      str(args.dir_checkpoint) + '/model-' + str(global_step))
                saver_S.save(get_session(sess),
                             str(args.dir_checkpoint / 'model_S'),
                             global_step=global_step)
                print(
                    'saved Spiker model in',
                    str(args.dir_checkpoint) + '/model_S-' + str(global_step))

            if global_step % args.dev_step == args.dev_step - 1:
                # if global_step % args.dev_step == 0:
                per, cer, wer = dev(step=global_step,
                                    dataloader=dataloader_dev,
                                    model=G_infer,
                                    sess=sess,
                                    unit=args.data.unit,
                                    args=args)
                summary.summary_scalar('dev_per', per, global_step)
                summary.summary_scalar('dev_cer', cer, global_step)
                summary.summary_scalar('dev_wer', wer, global_step)

            if global_step % args.decode_step == args.decode_step - 1:
                # if True:
                decode_test(step=global_step,
                            sample=args.dataset_test.uttid2sample(
                                args.sample_uttid),
                            model=G_infer,
                            sess=sess,
                            unit=args.data.unit,
                            args=args)

    logging.info('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
transform=False) dataset_test = ASR_phone_char_ArkDataSet(f_scp=args.dirs.test.scp, f_phone=args.dirs.test.phone, f_char=args.dirs.test.char, args=args, _shuffle=False, transform=True) else: dataset_dev = dataset_train = dataset_test = None args.dataset_dev = dataset_dev args.dataset_train = dataset_train args.dataset_test = dataset_test try: args.data.dim_feature = TFDataReader.read_tfdata_info( args.dirs.train.tfdata)['dim_feature'] args.data.train_size = TFDataReader.read_tfdata_info( args.dirs.train.tfdata)['size_dataset'] args.data.dev_size = TFDataReader.read_tfdata_info( args.dirs.dev.tfdata)['size_dataset'] args.data.dim_input = args.data.dim_feature * \ (args.data.right_context + args.data.left_context +1) *\ (3 if args.data.add_delta else 1) except: print("have not converted to tfdata yet: ") # model ## encoder if args.model.encoder.type == 'transformer_encoder': from models.encoders.transformer_encoder import Transformer_Encoder as encoder elif args.model.encoder.type == 'conv_transformer_encoder':
def train_gan():
    """Run semi-supervised GAN training for the multi-task ASR model.

    The generator `G` is trained on labelled batches; a discriminator `D`
    is trained (3 updates per generator step) to separate generator output
    from real text drawn from a `tf.data` text pipeline. Restores ASR
    weights from `args.dirs.checkpoint_G` when provided; periodically saves
    the generator, evaluates on the dev set and decodes a sample utterance.

    All configuration comes from the module-level `args` object.
    """
    print('reading data from ', args.dirs.train.tfdata)  # typo fix: was 'form'
    dataReader_train = TFDataReader(args.dirs.train.tfdata, args=args)
    batch_train = dataReader_train.fentch_multi_batch_bucket()
    dataReader_untrain = TFDataReader(args.dirs.untrain.tfdata, args=args)
    batch_untrain = dataReader_untrain.fentch_multi_batch(args.batch_size)
    args.dirs.untrain.tfdata = Path(args.dirs.untrain.tfdata)
    args.data.untrain_size = TFDataReader.read_tfdata_info(
        args.dirs.untrain.tfdata)['size_dataset']

    # Real-text pipeline feeding the discriminator.
    dataset_text = TextDataSet(list_files=[args.dirs.text.data],
                               args=args,
                               _shuffle=True)
    tfdata_train = tf.data.Dataset.from_generator(dataset_text, (tf.int32),
                                                  (tf.TensorShape([None])))
    iter_text = tfdata_train.cache().repeat().shuffle(1000).\
        padded_batch(args.text_batch_size, ([args.max_label_len])).prefetch(buffer_size=5).\
        make_one_shot_iterator().get_next()

    dataReader_dev = TFDataReader(args.dirs.dev.tfdata, args=args,
                                  _shuffle=False, transform=True)
    dataloader_dev = ASR_Multi_DataLoader(args.dataset_dev, args,
                                          dataReader_dev.feat,
                                          dataReader_dev.phone,
                                          dataReader_dev.label,
                                          batch_size=args.batch_size,
                                          num_loops=1)

    tensor_global_step = tf.train.get_or_create_global_step()
    # Separate step counters for the G and D optimizers.
    tensor_global_step0 = tf.Variable(0, dtype=tf.int32, trainable=False)
    tensor_global_step1 = tf.Variable(0, dtype=tf.int32, trainable=False)

    G = args.Model(tensor_global_step,
                   encoder=args.model.encoder.type,
                   encoder2=args.model.encoder2.type,
                   decoder=args.model.decoder.type,
                   batch=batch_train,
                   training=True,
                   args=args)
    G_infer = args.Model(tensor_global_step,
                         encoder=args.model.encoder.type,
                         encoder2=args.model.encoder2.type,
                         decoder=args.model.decoder.type,
                         training=False,
                         args=args)
    vars_ASR = G.trainable_variables()
    # vars_G_ocd = G.trainable_variables('Ectc_Docd/ocd_decoder')

    D = args.Model_D(tensor_global_step1,
                     training=True,
                     name='discriminator',
                     args=args)
    gan = args.GAN([tensor_global_step0, tensor_global_step1], G, D,
                   batch=batch_train, unbatch=batch_untrain,
                   name='GAN', args=args)

    size_variables()
    start_time = datetime.now()
    saver_ASR = tf.train.Saver(vars_ASR, max_to_keep=10)
    # NOTE(review): this all-variables saver is never used below — confirm
    # it can be dropped (kept here since Saver creation touches the graph).
    saver = tf.train.Saver(max_to_keep=15)
    summary = Summary(str(args.dir_log))

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        dataloader_dev.sess = sess
        if args.dirs.checkpoint_G:
            saver_ASR.restore(sess, args.dirs.checkpoint_G)

        batch_time = time()
        num_processed = 0
        num_processed_unbatch = 0
        progress = 0
        while progress < args.num_epochs:
            # semi_supervise
            global_step, lr_G, lr_D = sess.run([
                tensor_global_step0, gan.learning_rate_G, gan.learning_rate_D
            ])

            # Train the discriminator 3x per generator update.
            for _ in range(3):
                text = sess.run(iter_text)
                text_lens = get_batch_length(text)
                shape_text = text.shape
                loss_D, loss_D_res, loss_D_text, loss_gp, _ = sess.run(
                    gan.list_train_D,
                    feed_dict={
                        gan.list_pl[0]: text,
                        gan.list_pl[1]: text_lens
                    })
                # loss_D=loss_D_res=loss_D_text=loss_gp=0

            (loss_G, ctc_loss, ce_loss, _), (shape_batch, shape_unbatch) = \
                sess.run([gan.list_train_G, gan.list_feature_shape])
            num_processed += shape_batch[0]
            # num_processed_unbatch += shape_unbatch[0]

            used_time = time() - batch_time
            batch_time = time()
            progress = num_processed / args.data.train_size
            # NOTE(review): always 0 while the increment above stays
            # commented out — confirm whether unbatch progress is wanted.
            progress_unbatch = num_processed_unbatch / args.data.untrain_size

            if global_step % 40 == 0:
                print('ctc|ce loss: {:.2f}|{:.2f}, loss res|real|gp: {:.2f}|{:.2f}|{:.2f}\t{}|{}\tlr:{:.1e}|{:.1e} {:.2f}s {:.3f}% step: {}'.format(
                    np.mean(ctc_loss), np.mean(ce_loss), loss_D_res, loss_D_text, loss_gp, shape_batch, \
                    shape_unbatch, lr_G, lr_D, used_time, progress*100, global_step))
                summary.summary_scalar('ctc_loss', np.mean(ctc_loss), global_step)
                summary.summary_scalar('ce_loss', np.mean(ce_loss), global_step)

            if global_step % args.save_step == args.save_step - 1:
                saver_ASR.save(get_session(sess),
                               str(args.dir_checkpoint / 'model_G'),
                               global_step=global_step,
                               write_meta_graph=True)
                print(
                    'saved G in',
                    str(args.dir_checkpoint) + '/model_G-' + str(global_step))

            if global_step % args.dev_step == args.dev_step - 1:
                # if True:
                per, cer, wer = dev(step=global_step,
                                    dataloader=dataloader_dev,
                                    model=G_infer,
                                    sess=sess,
                                    unit=args.data.unit,
                                    args=args)
                summary.summary_scalar('dev_per', per, global_step)
                summary.summary_scalar('dev_cer', cer, global_step)
                summary.summary_scalar('dev_wer', wer, global_step)

            if global_step % args.decode_step == args.decode_step - 1:
                # if True:
                decode_test(step=global_step,
                            sample=args.dataset_test.uttid2sample(
                                args.sample_uttid),
                            model=G_infer,
                            sess=sess,
                            unit=args.data.unit,
                            args=args)

    logging.info('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
def check():
    """Interactive smoke test of the demo data pipeline.

    Converts the demo scp dataset to tfdata, builds a bucketed training
    batch and a dev data loader over it, then drops into pdb inside a
    session so a developer can inspect one batch / one dev sample / one
    raw dataset item by hand.
    """
    import tensorflow as tf
    from pathlib import Path
    from utils.dataset import ASR_scp_DataSet, ASRDataLoader
    from models.utils.tfData import TFDataReader

    # Raw (untransformed) dataset, converted to tfdata on disk first.
    ds_raw = ASR_scp_DataSet(f_scp=args.dirs.demo.scp,
                             f_trans=args.dirs.demo.trans,
                             args=args,
                             _shuffle=False,
                             transform=False)
    TFDataSaver(ds_raw, Path(args.dirs.demo.tfdata), args,
                size_file=1, max_feat_len=3000).split_save()

    # train-style reader: shuffled, transformed, bucketed batches.
    reader_train = TFDataReader(args.dirs.demo.tfdata, args=args,
                                _shuffle=True, transform=True)
    batch = reader_train.fentch_batch_bucket()

    # dev-style reader feeding a loader over the raw dataset.
    reader_dev = TFDataReader(args.dirs.demo.tfdata, args=args,
                              _shuffle=False, transform=True)
    loader_dev = ASRDataLoader(ds_raw, args,
                               reader_dev.feat, reader_dev.label,
                               batch_size=2, num_loops=1)

    # test-style datasets built straight from the scp files.
    ds_transformed = ASR_scp_DataSet(f_scp=args.dirs.demo.scp,
                                     f_trans=args.dirs.demo.trans,
                                     args=args,
                                     _shuffle=False,
                                     transform=True)
    ds_raw_again = ASR_scp_DataSet(f_scp=args.dirs.demo.scp,
                                   f_trans=args.dirs.demo.trans,
                                   args=args,
                                   _shuffle=False,
                                   transform=False)

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        loader_dev.sess = sess
        # Deliberate breakpoint: poke at `batch`, `sample_dev`, `sample`.
        import pdb
        pdb.set_trace()
        batch = sess.run(batch)
        sample_dev = next(iter(loader_dev))
        sample = ds_transformed[0]