def train():
    print('reading data from ', args.dirs.train.tfdata)
    dataReader_train = TFDataReader(args.dirs.train.tfdata, args=args)
    batch_train = dataReader_train.fentch_multi_batch_bucket()

    dataReader_dev = TFDataReader(
        args.dirs.dev.tfdata, args=args, _shuffle=False, transform=True)
    dataloader_dev = ASR_Multi_DataLoader(
        args.dataset_dev, args,
        dataReader_dev.feat, dataReader_dev.phone, dataReader_dev.label,
        batch_size=args.batch_size, num_loops=1)

    tensor_global_step = tf.train.get_or_create_global_step()

    # training and inference graphs share the same encoder/decoder config and global step
    G = args.Model(
        tensor_global_step,
        encoder=args.model.encoder.type,
        encoder2=args.model.encoder2.type,
        decoder=args.model.decoder.type,
        batch=batch_train,
        training=True,
        args=args)
    G_infer = args.Model(
        tensor_global_step,
        encoder=args.model.encoder.type,
        encoder2=args.model.encoder2.type,
        decoder=args.model.decoder.type,
        training=False,
        args=args)
    vars_ASR = G.trainable_variables()
    vars_spiker = G.trainable_variables(G.name + '/spiker')

    size_variables()
    start_time = datetime.now()
    saver_ASR = tf.train.Saver(vars_ASR, max_to_keep=30)
    saver_S = tf.train.Saver(vars_spiker, max_to_keep=30)
    saver = tf.train.Saver(max_to_keep=15)
    summary = Summary(str(args.dir_log))
    step_bias = 0

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        dataloader_dev.sess = sess

        # restore pretrained checkpoints if given; recover the step offset from the filename
        if args.dirs.checkpoint_G:
            saver_ASR.restore(sess, args.dirs.checkpoint_G)
            step_bias = int(args.dirs.checkpoint_G.split('-')[-1])
        if args.dirs.checkpoint_S:
            saver_S.restore(sess, args.dirs.checkpoint_S)

        batch_time = time()
        num_processed = 0
        progress = 0
        while progress < args.num_epochs:
            # supervised training
            global_step, lr = sess.run([tensor_global_step, G.learning_rate])
            global_step += step_bias
            loss_G, shape_batch, _, (ctc_loss, ce_loss, *_) = sess.run(G.list_run)

            # epochs are tracked by counting processed utterances
            num_processed += shape_batch[0]
            used_time = time() - batch_time
            batch_time = time()
            progress = num_processed / args.data.train_size

            if global_step % 40 == 0:
                print('ctc_loss: {:.2f}, ce_loss: {:.2f} batch: {} lr:{:.1e} {:.2f}s {:.3f}% step: {}'.format(
                    np.mean(ctc_loss), np.mean(ce_loss), shape_batch, lr,
                    used_time, progress * 100, global_step))

            if global_step % args.save_step == args.save_step - 1:
                saver_ASR.save(get_session(sess), str(args.dir_checkpoint / 'model'),
                               global_step=global_step)
                print('saved ASR model in',
                      str(args.dir_checkpoint) + '/model-' + str(global_step))
                saver_S.save(get_session(sess), str(args.dir_checkpoint / 'model_S'),
                             global_step=global_step)
                print('saved Spiker model in',
                      str(args.dir_checkpoint) + '/model_S-' + str(global_step))

            if global_step % args.dev_step == args.dev_step - 1:
                # if global_step % args.dev_step == 0:
                per, cer, wer = dev(
                    step=global_step,
                    dataloader=dataloader_dev,
                    model=G_infer,
                    sess=sess,
                    unit=args.data.unit,
                    args=args)
                summary.summary_scalar('dev_per', per, global_step)
                summary.summary_scalar('dev_cer', cer, global_step)
                summary.summary_scalar('dev_wer', wer, global_step)

            if global_step % args.decode_step == args.decode_step - 1:
                # if True:
                decode_test(
                    step=global_step,
                    sample=args.dataset_test.uttid2sample(args.sample_uttid),
                    model=G_infer,
                    sess=sess,
                    unit=args.data.unit,
                    args=args)

    logging.info('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
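
# --- Illustration (not part of the training graph) --------------------------
# `train()` above resumes step counting by parsing the step suffix from the
# restored checkpoint prefix (e.g. '.../model-12000' -> 12000), relying on the
# TF '<prefix>-<global_step>' naming convention.  A minimal standalone sketch;
# the example path below is hypothetical.
def step_from_checkpoint(ckpt_path):
    """Parse the global step encoded at the end of a TF checkpoint prefix."""
    return int(str(ckpt_path).split('-')[-1])

# usage (hypothetical path):
# step_bias = step_from_checkpoint('exp/checkpoints/model-12000')  # -> 12000
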
def train_gan():
    print('reading data from ', args.dirs.train.tfdata)
    dataReader_train = TFDataReader(args.dirs.train.tfdata, args=args)
    batch_train = dataReader_train.fentch_multi_batch_bucket()
    dataReader_untrain = TFDataReader(args.dirs.untrain.tfdata, args=args)
    batch_untrain = dataReader_untrain.fentch_multi_batch(args.batch_size)

    args.dirs.untrain.tfdata = Path(args.dirs.untrain.tfdata)
    args.data.untrain_size = TFDataReader.read_tfdata_info(
        args.dirs.untrain.tfdata)['size_dataset']

    # unpaired text stream used to train the discriminator
    dataset_text = TextDataSet(list_files=[args.dirs.text.data], args=args, _shuffle=True)
    tfdata_train = tf.data.Dataset.from_generator(
        dataset_text, (tf.int32), (tf.TensorShape([None])))
    iter_text = tfdata_train.cache().repeat().shuffle(1000).\
        padded_batch(args.text_batch_size, ([args.max_label_len])).prefetch(buffer_size=5).\
        make_one_shot_iterator().get_next()

    dataReader_dev = TFDataReader(
        args.dirs.dev.tfdata, args=args, _shuffle=False, transform=True)
    dataloader_dev = ASR_Multi_DataLoader(
        args.dataset_dev, args,
        dataReader_dev.feat, dataReader_dev.phone, dataReader_dev.label,
        batch_size=args.batch_size, num_loops=1)

    tensor_global_step = tf.train.get_or_create_global_step()
    tensor_global_step0 = tf.Variable(0, dtype=tf.int32, trainable=False)
    tensor_global_step1 = tf.Variable(0, dtype=tf.int32, trainable=False)

    G = args.Model(
        tensor_global_step,
        encoder=args.model.encoder.type,
        encoder2=args.model.encoder2.type,
        decoder=args.model.decoder.type,
        batch=batch_train,
        training=True,
        args=args)
    G_infer = args.Model(
        tensor_global_step,
        encoder=args.model.encoder.type,
        encoder2=args.model.encoder2.type,
        decoder=args.model.decoder.type,
        training=False,
        args=args)
    vars_ASR = G.trainable_variables()
    # vars_G_ocd = G.trainable_variables('Ectc_Docd/ocd_decoder')

    D = args.Model_D(tensor_global_step1, training=True, name='discriminator', args=args)

    gan = args.GAN([tensor_global_step0, tensor_global_step1], G, D,
                   batch=batch_train, unbatch=batch_untrain, name='GAN', args=args)

    size_variables()
    start_time = datetime.now()
    saver_ASR = tf.train.Saver(vars_ASR, max_to_keep=10)
    saver = tf.train.Saver(max_to_keep=15)
    summary = Summary(str(args.dir_log))

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        dataloader_dev.sess = sess
        if args.dirs.checkpoint_G:
            saver_ASR.restore(sess, args.dirs.checkpoint_G)

        batch_time = time()
        num_processed = 0
        num_processed_unbatch = 0
        progress = 0
        while progress < args.num_epochs:
            # semi-supervised training
            global_step, lr_G, lr_D = sess.run([
                tensor_global_step0, gan.learning_rate_G, gan.learning_rate_D])

            # update the discriminator several times per generator step
            for _ in range(3):
                text = sess.run(iter_text)
                text_lens = get_batch_length(text)
                shape_text = text.shape
                loss_D, loss_D_res, loss_D_text, loss_gp, _ = sess.run(
                    gan.list_train_D,
                    feed_dict={gan.list_pl[0]: text,
                               gan.list_pl[1]: text_lens})
                # loss_D = loss_D_res = loss_D_text = loss_gp = 0

            (loss_G, ctc_loss, ce_loss, _), (shape_batch, shape_unbatch) = \
                sess.run([gan.list_train_G, gan.list_feature_shape])

            num_processed += shape_batch[0]
            # num_processed_unbatch += shape_unbatch[0]
            used_time = time() - batch_time
            batch_time = time()
            progress = num_processed / args.data.train_size
            progress_unbatch = num_processed_unbatch / args.data.untrain_size

            if global_step % 40 == 0:
                print('ctc|ce loss: {:.2f}|{:.2f}, loss res|real|gp: {:.2f}|{:.2f}|{:.2f}\t{}|{}\tlr:{:.1e}|{:.1e} {:.2f}s {:.3f}% step: {}'.format(
                    np.mean(ctc_loss), np.mean(ce_loss), loss_D_res, loss_D_text,
                    loss_gp, shape_batch, shape_unbatch, lr_G, lr_D,
                    used_time, progress * 100, global_step))
                summary.summary_scalar('ctc_loss', np.mean(ctc_loss), global_step)
                summary.summary_scalar('ce_loss', np.mean(ce_loss), global_step)

            if global_step % args.save_step == args.save_step - 1:
                saver_ASR.save(get_session(sess), str(args.dir_checkpoint / 'model_G'),
                               global_step=global_step, write_meta_graph=True)
                print('saved G in',
                      str(args.dir_checkpoint) + '/model_G-' + str(global_step))
                # saver_G_en.save(get_session(sess), str(args.dir_checkpoint/'model_G_en'), global_step=global_step, write_meta_graph=True)
                # print('saved model in', str(args.dir_checkpoint)+'/model_G_en-'+str(global_step))

            if global_step % args.dev_step == args.dev_step - 1:
                # if True:
                per, cer, wer = dev(
                    step=global_step,
                    dataloader=dataloader_dev,
                    model=G_infer,
                    sess=sess,
                    unit=args.data.unit,
                    args=args)
                summary.summary_scalar('dev_per', per, global_step)
                summary.summary_scalar('dev_cer', cer, global_step)
                summary.summary_scalar('dev_wer', wer, global_step)

            if global_step % args.decode_step == args.decode_step - 1:
                # if True:
                decode_test(
                    step=global_step,
                    sample=args.dataset_test.uttid2sample(args.sample_uttid),
                    model=G_infer,
                    sess=sess,
                    unit=args.data.unit,
                    args=args)

    logging.info('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
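
# --- Sketch: what `get_batch_length` is assumed to do ------------------------
# `train_gan()` feeds the discriminator padded text batches together with their
# true lengths via `get_batch_length(text)`.  A minimal numpy sketch of such a
# helper, assuming 0 is the padding id produced by `padded_batch` above (the
# real helper in this repo may differ).
import numpy as np


def get_batch_length_sketch(padded_ids):
    """Count non-padding tokens per row of a [batch, max_len] int array."""
    return np.sum(np.asarray(padded_ids) != 0, axis=-1).astype(np.int32)

# usage:
# text = np.array([[5, 7, 2, 0, 0], [3, 1, 0, 0, 0]])
# get_batch_length_sketch(text)  # -> array([3, 2], dtype=int32)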