def train():
    """Adversarially train an ASR generator G against a text discriminator D (TF1).

    Pipeline: build train/untrain/dev data readers and a text iterator, build
    G (training), G_infer (shared variables, inference graph), D, and the GAN
    wrapper; warm G up with 2000 CTC steps, then alternate GAN updates until
    `progress` (fraction of supervised data seen) reaches args.num_epochs.
    Relies on module-level `args` and helpers (TFReader, readTFRecord, dev,
    decode_test, ...); runs for side effects only.
    """
    # FIX: log message typo "form " -> "from ".
    print('reading data from ', args.dirs.train.tfdata)
    dataReader_train = TFReader(args.dirs.train.tfdata, args=args)
    batch_train = dataReader_train.fentch_batch_bucket()
    dataReader_untrain = TFReader(args.dirs.untrain.tfdata, args=args)
    batch_untrain = dataReader_untrain.fentch_batch(args.batch_size)
    # batch_untrain = dataReader_untrain.fentch_batch_bucket()

    args.dirs.untrain.tfdata = Path(args.dirs.untrain.tfdata)
    args.data.untrain_size = TFData.read_tfdata_info(
        args.dirs.untrain.tfdata)['size_dataset']

    # Unpaired text stream feeding the discriminator's "real" side.
    dataset_text = TextDataSet(list_files=[args.dirs.text.data],
                               args=args,
                               _shuffle=True)
    tfdata_train = tf.data.Dataset.from_generator(dataset_text, (tf.int32),
                                                  (tf.TensorShape([None])))
    iter_text = tfdata_train.cache().repeat().shuffle(1000).\
        padded_batch(args.text_batch_size, ([args.max_label_len])).prefetch(buffer_size=5).\
        make_one_shot_iterator().get_next()

    feat, label = readTFRecord(args.dirs.dev.tfdata,
                               args,
                               _shuffle=False,
                               transform=True)
    dataloader_dev = ASRDataLoader(args.dataset_dev,
                                   args,
                                   feat,
                                   label,
                                   batch_size=args.batch_size,
                                   num_loops=1)

    tensor_global_step = tf.train.get_or_create_global_step()
    tensor_global_step0 = tf.Variable(0, dtype=tf.int32, trainable=False)
    tensor_global_step1 = tf.Variable(0, dtype=tf.int32, trainable=False)

    # Training and inference graphs share variables via tensor_global_step.
    G = args.Model(tensor_global_step,
                   encoder=args.model.encoder.type,
                   decoder=args.model.decoder.type,
                   batch=batch_train,
                   training=True,
                   args=args)
    G_infer = args.Model(tensor_global_step,
                         encoder=args.model.encoder.type,
                         decoder=args.model.decoder.type,
                         training=False,
                         args=args)
    vars_G = G.trainable_variables

    D = args.Model_D(tensor_global_step1,
                     training=True,
                     name='discriminator',
                     args=args)
    gan = args.GAN([tensor_global_step0, tensor_global_step1],
                   G,
                   D,
                   batch=batch_train,
                   unbatch=batch_untrain,
                   name='GAN',
                   args=args)

    size_variables()
    start_time = datetime.now()
    saver_G = tf.train.Saver(vars_G, max_to_keep=1)
    saver = tf.train.Saver(max_to_keep=15)
    summary = Summary(str(args.dir_log))

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        dataloader_dev.sess = sess
        if args.dirs.checkpoint_G:
            saver_G.restore(sess, args.dirs.checkpoint_G)

        # CTC warm-up of G before adversarial training begins.
        for i in range(2000):
            ctc_loss, *_ = sess.run(G.list_run)
            # BUG FIX: original `if i % 400:` printed on every iteration NOT
            # divisible by 400 (1995 of 2000 steps); log every 400th instead.
            if i % 400 == 0:
                print('ctc_loss: {:.2f}'.format(ctc_loss))

        batch_time = time()
        num_processed = 0
        num_processed_unbatch = 0
        progress = 0
        while progress < args.num_epochs:
            global_step, lr_G, lr_D = sess.run([
                tensor_global_step0, gan.learning_rate_G, gan.learning_rate_D
            ])

            # untrain
            text = sess.run(iter_text)
            text_lens = get_batch_length(text)
            shape_text = text.shape
            # The discriminator update is currently disabled; losses are
            # reported as 0 in the progress line below.
            # loss_D, loss_D_res, loss_D_text, loss_gp, _ = sess.run(gan.list_train_D,
            #                                 feed_dict={gan.list_pl[0]:text,
            #                                            gan.list_pl[1]:text_lens})
            loss_D = loss_D_res = loss_D_text = loss_gp = 0
            (loss_G, loss_G_supervise, _), (shape_feature, shape_unfeature) = \
                sess.run([gan.list_train_G, gan.list_feature_shape])
            num_processed += shape_feature[0]
            num_processed_unbatch += shape_unfeature[0]
            used_time = time() - batch_time
            batch_time = time()
            # `progress` counts epochs of the supervised set (fraction seen).
            progress = num_processed / args.data.train_size
            progress_unbatch = num_processed_unbatch / args.data.untrain_size

            if global_step % 20 == 0:
                print(
                    'loss_supervise: {:.2f}, loss res|real|gp: {:.2f}|{:.2f}|{:.2f}\tbatch: {}|{}|{}\tlr:{:.1e}|{:.1e} {:.2f}s {:.1f}%|{:.1f}% step: {}'
                    .format(loss_G_supervise, loss_D_res, loss_D_text, loss_gp,
                            shape_feature, shape_unfeature, shape_text, lr_G,
                            lr_D, used_time, progress * 100.0,
                            progress_unbatch * 100.0, global_step))
                # summary.summary_scalar('loss_G', loss_G, global_step)
                # summary.summary_scalar('loss_D', loss_D, global_step)
                # summary.summary_scalar('lr_G', lr_G, global_step)
                # summary.summary_scalar('lr_D', lr_D, global_step)

            if global_step % args.save_step == args.save_step - 1:
                saver.save(get_session(sess),
                           str(args.dir_checkpoint / 'model'),
                           global_step=global_step,
                           write_meta_graph=True)

            # if global_step % args.dev_step == args.dev_step - 1:
            if global_step % args.dev_step == 0:
                cer, wer = dev(step=global_step,
                               dataloader=dataloader_dev,
                               model=G_infer,
                               sess=sess,
                               unit=args.data.unit,
                               idx2token=args.idx2token,
                               token2idx=args.token2idx)
                # summary.summary_scalar('dev_cer', cer, global_step)
                # summary.summary_scalar('dev_wer', wer, global_step)

            if global_step % args.decode_step == args.decode_step - 1:
                decode_test(step=global_step,
                            sample=args.dataset_test[10],
                            model=G_infer,
                            sess=sess,
                            unit=args.data.unit,
                            idx2token=args.idx2token,
                            token2idx=args.token2idx)

    logging.info('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
def train():
    """Alternate CTC and GAN training phases for a phone classifier (TF2).

    Builds aligned train/dev/supervised datasets and a text stream, then loops
    forever: in the CTC phase G is trained on the supervised subset; in the GAN
    phase D is updated `args.opti.D_G_rate` times per G update. When dev CER
    stops improving for 2000 steps, the phase flips and G is rolled back to its
    best-so-far weights. Relies on module-level `args` and helpers
    (train_CTC_G, train_D, train_GAN_G, evaluate, save_varibales, ...).
    """
    with tf.device("/cpu:0"):
        dataset_train = ASR_align_DataSet(
            trans_file=args.dirs.train.trans,
            align_file=args.dirs.train.align,
            uttid2wav=args.dirs.train.wav_scp,
            feat_len_file=args.dirs.train.feat_len,
            args=args,
            _shuffle=True,
            transform=True)
        dataset_dev = ASR_align_DataSet(trans_file=args.dirs.dev.trans,
                                        align_file=args.dirs.dev.align,
                                        uttid2wav=args.dirs.dev.wav_scp,
                                        feat_len_file=args.dirs.dev.feat_len,
                                        args=args,
                                        _shuffle=False,
                                        transform=True)
        dataset_train_supervise = ASR_align_DataSet(
            trans_file=args.dirs.train_supervise.trans,
            align_file=args.dirs.train_supervise.align,
            uttid2wav=args.dirs.train_supervise.wav_scp,
            feat_len_file=args.dirs.train_supervise.feat_len,
            args=args,
            _shuffle=False,
            transform=True)
        feature_train_supervise = TFData(
            dataset=dataset_train_supervise,
            dir_save=args.dirs.train_supervise.tfdata,
            args=args).read()
        feature_train = TFData(dataset=dataset_train,
                               dir_save=args.dirs.train.tfdata,
                               args=args).read()
        feature_dev = TFData(dataset=dataset_dev,
                             dir_save=args.dirs.dev.tfdata,
                             args=args).read()

        iter_feature_supervise = iter(
            feature_train_supervise.cache().repeat().padded_batch(
                args.batch_size,
                ((), [None, args.dim_input])).prefetch(buffer_size=5))
        iter_feature_train = iter(
            feature_train.cache().repeat().shuffle(500).padded_batch(
                args.batch_size,
                ((), [None, args.dim_input])).prefetch(buffer_size=5))
        feature_dev = feature_dev.padded_batch(args.batch_size,
                                               ((), [None, args.dim_input]))

        dataset_text = TextDataSet(list_files=[args.dirs.lm.data],
                                   args=args,
                                   _shuffle=True)
        tfdata_train = tf.data.Dataset.from_generator(dataset_text, (tf.int32),
                                                      (tf.TensorShape([None])))
        iter_text = iter(tfdata_train.cache().repeat().shuffle(1000).map(
            lambda x: x[:args.model.D.max_label_len]).padded_batch(
                args.batch_size,
                ([args.model.D.max_label_len])).prefetch(buffer_size=5))

    # create model parameters
    G = PhoneClassifier(args)
    D = PhoneDiscriminator3(args)
    G.summary()
    D.summary()
    # NOTE(review): beta_1=0.1/beta_2=0.5 differs from every sibling trainer
    # in this file (0.5/0.9) — looks intentional for this variant but confirm.
    optimizer_G = tf.keras.optimizers.Adam(args.opti.G.lr,
                                           beta_1=0.1,
                                           beta_2=0.5)
    optimizer_D = tf.keras.optimizers.Adam(args.opti.D.lr,
                                           beta_1=0.5,
                                           beta_2=0.9)

    writer = tf.summary.create_file_writer(str(args.dir_log))
    ckpt = tf.train.Checkpoint(G=G, optimizer_G=optimizer_G)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              args.dir_checkpoint,
                                              max_to_keep=20)
    step = 0
    best = 999
    phrase_ctc = False
    step_retention = 0
    # BUG FIX: G_values was only ever assigned inside `if cer < best:`; if the
    # phase flipped before any improvement, load_values() raised NameError.
    G_values = None

    # if a checkpoint exists, restore the latest checkpoint.
    if args.dirs.checkpoint:
        ckpt.restore(args.dirs.checkpoint)
        print('checkpoint {} restored!!'.format(args.dirs.checkpoint))
        step = int(args.dirs.checkpoint.split('-')[-1])

    start_time = datetime.now()
    num_processed = 0
    progress = 0

    while step < 99999999:
        start = time()

        if phrase_ctc:
            # CTC phase: supervised training of G only.
            uttids, x = next(iter_feature_supervise)
            trans = dataset_train_supervise.get_attrs('trans', uttids.numpy())
            loss_G_ctc = train_CTC_G(x, trans, G, D, optimizer_G)
        else:
            # GAN phase.
            # D sub-phase: several discriminator updates per G update.
            for _ in range(args.opti.D_G_rate):
                uttids, x = next(iter_feature_train)
                text = next(iter_text)
                P_Real = tf.one_hot(text, args.dim_output)
                loss_D, loss_D_fake, loss_D_real, gp = train_D(
                    x, P_Real, text > 0, G, D, optimizer_D, args.lambda_gp,
                    args.model.D.max_label_len)
            # G sub-phase: adversarial update with a supervised auxiliary term.
            uttids, x = next(iter_feature_train)
            _uttids, _x = next(iter_feature_supervise)
            _trans = dataset_train_supervise.get_attrs('trans',
                                                       _uttids.numpy())
            train_GAN_G(x, _x, _trans, G, D, optimizer_G,
                        args.model.D.max_label_len, args.lambda_supervise)

        num_processed += len(x)
        progress = num_processed / args.data.train_size
        if step % 10 == 0:
            if phrase_ctc:
                print(
                    'loss_G_ctc: {:.3f}\tbatch: {}\tused: {:.3f}\t {:.3f}% iter: {}'
                    .format(loss_G_ctc, x.shape,
                            time() - start, progress * 100.0, step))
                with writer.as_default():
                    tf.summary.scalar("costs/loss_G_supervise",
                                      loss_G_ctc,
                                      step=step)
            else:
                print(
                    'loss_GAN: {:.3f}|{:.3f}|{:.3f}\tbatch: {}|{}\tused: {:.3f}\t {:.3f}% iter: {}'
                    .format(loss_D_fake, loss_D_real, gp, x.shape, text.shape,
                            time() - start, progress * 100.0, step))

        if step % args.dev_step == 0:
            cer = evaluate(feature_dev, dataset_dev, args.data.dev_size, G)
            if cer < best:
                # New best: snapshot G's weights so we can roll back later.
                best = cer
                G_values = save_varibales(G)
            elif step_retention < 2000:
                pass
            else:
                # No improvement for 2000 steps: switch phase and restore the
                # best weights seen so far (if any snapshot exists).
                phrase_ctc = not phrase_ctc
                step_retention = 0
                if G_values is not None:
                    load_values(G, G_values)
                print('====== switching phrase ======')
            with writer.as_default():
                tf.summary.scalar("performance/cer", cer, step=step)
        if step % args.decode_step == 0:
            monitor(dataset_dev[0], G)
        if step % args.save_step == 0:
            save_path = ckpt_manager.save(step)
            print('save model {}'.format(save_path))

        step += 1
        step_retention += 1

    print('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
def train():
    """GAN training of a phone classifier with frame-stamp conditioning (TF2).

    Builds aligned train/dev datasets plus one fixed supervised batch
    (args.num_supervised utterances), then loops forever alternating
    D updates (args.opti.D_G_rate per step), one adversarial G update, and one
    supervised G update on the fixed batch. Periodically evaluates, decodes
    one dev sample, and checkpoints. Relies on module-level `args` and helpers
    (train_D, train_G, train_G_supervised, evaluate, monitor, ...).
    """
    with tf.device("/cpu:0"):
        dataset_train = ASR_align_DataSet(
            trans_file=args.dirs.train.trans,
            align_file=args.dirs.train.align,
            uttid2wav=args.dirs.train.wav_scp,
            feat_len_file=args.dirs.train.feat_len,
            args=args,
            _shuffle=True,
            transform=True)
        dataset_dev = ASR_align_DataSet(trans_file=args.dirs.dev.trans,
                                        align_file=args.dirs.dev.align,
                                        uttid2wav=args.dirs.dev.wav_scp,
                                        feat_len_file=args.dirs.dev.feat_len,
                                        args=args,
                                        _shuffle=False,
                                        transform=True)
        dataset_train_supervise = ASR_align_DataSet(
            trans_file=args.dirs.train_supervise.trans,
            align_file=args.dirs.train_supervise.align,
            uttid2wav=args.dirs.train_supervise.wav_scp,
            feat_len_file=args.dirs.train_supervise.feat_len,
            args=args,
            _shuffle=False,
            transform=True)
        feature_train_supervise = TFData(
            dataset=dataset_train_supervise,
            dir_save=args.dirs.train_supervise.tfdata,
            args=args).read()
        feature_train = TFData(dataset=dataset_train,
                               dir_save=args.dirs.train.tfdata,
                               args=args).read()
        feature_dev = TFData(dataset=dataset_dev,
                             dir_save=args.dirs.dev.tfdata,
                             args=args).read()

        # One fixed supervised batch, reused for every supervised G update.
        supervise_uttids, supervise_x = next(iter(feature_train_supervise.take(args.num_supervised).\
            padded_batch(args.num_supervised, ((), [None, args.dim_input]))))
        supervise_aligns = dataset_train_supervise.get_attrs(
            'align', supervise_uttids.numpy())
        # `bounds` is fetched but only used by the commented-out
        # train_G_bounds_supervised variant below.
        supervise_bounds = dataset_train_supervise.get_attrs(
            'bounds', supervise_uttids.numpy())

        iter_feature_train = iter(
            feature_train.cache().repeat().shuffle(500).padded_batch(
                args.batch_size,
                ((), [None, args.dim_input])).prefetch(buffer_size=5))
        feature_dev = feature_dev.padded_batch(args.batch_size,
                                               ((), [None, args.dim_input]))

        # Unpaired text stream feeding the discriminator's "real" side.
        dataset_text = TextDataSet(list_files=[args.dirs.lm.data],
                                   args=args,
                                   _shuffle=True)
        tfdata_train = tf.data.Dataset.from_generator(dataset_text, (tf.int32),
                                                      (tf.TensorShape([None])))
        iter_text = iter(tfdata_train.cache().repeat().shuffle(1000).map(
            lambda x: x[:args.model.D.max_label_len]).padded_batch(
                args.batch_size,
                ([args.model.D.max_label_len])).prefetch(buffer_size=5))

    # create model parameters
    G = PhoneClassifier(args)
    D = PhoneDiscriminator3(args)
    G.summary()
    D.summary()
    optimizer_G = tf.keras.optimizers.Adam(args.opti.G.lr,
                                           beta_1=0.5,
                                           beta_2=0.9)
    optimizer_D = tf.keras.optimizers.Adam(args.opti.D.lr,
                                           beta_1=0.5,
                                           beta_2=0.9)

    writer = tf.summary.create_file_writer(str(args.dir_log))
    # Only G and its optimizer are checkpointed; D restarts fresh on resume.
    ckpt = tf.train.Checkpoint(G=G, optimizer_G=optimizer_G)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              args.dir_checkpoint,
                                              max_to_keep=20)
    step = 0

    # if a checkpoint exists, restore the latest checkpoint.
    if args.dirs.checkpoint:
        _ckpt_manager = tf.train.CheckpointManager(ckpt,
                                                   args.dirs.checkpoint,
                                                   max_to_keep=1)
        ckpt.restore(_ckpt_manager.latest_checkpoint)
        print('checkpoint {} restored!!'.format(
            _ckpt_manager.latest_checkpoint))
        # Recover the step counter from the checkpoint file name suffix.
        step = int(_ckpt_manager.latest_checkpoint.split('-')[-1])

    start_time = datetime.now()
    num_processed = 0
    progress = 0

    while step < 99999999:
        start = time()

        # D sub-phase: several discriminator updates per G update.
        for _ in range(args.opti.D_G_rate):
            uttids, x = next(iter_feature_train)
            stamps = dataset_train.get_attrs('stamps', uttids.numpy())
            text = next(iter_text)
            P_Real = tf.one_hot(text, args.dim_output)
            # `text > 0` masks out padding positions in the real-text batch.
            cost_D, gp = train_D(x, stamps, P_Real, text > 0, G, D,
                                 optimizer_D, args.lambda_gp,
                                 args.model.D.max_label_len)
            # cost_D, gp = train_D(x, P_Real, text>0, G, D, optimizer_D,
            #     args.lambda_gp, args.model.G.len_seq, args.model.D.max_label_len)

        # G sub-phase: one adversarial update, then one supervised update on
        # the fixed supervised batch.
        uttids, x = next(iter_feature_train)
        stamps = dataset_train.get_attrs('stamps', uttids.numpy())
        cost_G, fs = train_G(x, stamps, G, D, optimizer_G, args.lambda_fs)
        # cost_G, fs = train_G(x, G, D, optimizer_G,
        #     args.lambda_fs, args.model.G.len_seq, args.model.D.max_label_len)
        loss_supervise = train_G_supervised(supervise_x, supervise_aligns, G,
                                            optimizer_G, args.dim_output,
                                            args.lambda_supervision)
        # loss_supervise, bounds_loss = train_G_bounds_supervised(
        #     supervise_x, supervise_bounds, supervise_aligns, G, optimizer_G, args.dim_output)

        num_processed += len(x)
        progress = num_processed / args.data.train_size
        if step % 10 == 0:
            print(
                'cost_G: {:.3f}|{:.3f}\tcost_D: {:.3f}|{:.3f}\tloss_supervise: {:.3f}\tbatch: {}|{}\tused: {:.3f}\t {:.3f}% iter: {}'
                .format(cost_G, fs, cost_D, gp, loss_supervise, x.shape,
                        text.shape,
                        time() - start, progress * 100.0, step))
            with writer.as_default():
                tf.summary.scalar("costs/cost_G", cost_G, step=step)
                tf.summary.scalar("costs/cost_D", cost_D, step=step)
                tf.summary.scalar("costs/gp", gp, step=step)
                tf.summary.scalar("costs/fs", fs, step=step)
                tf.summary.scalar("costs/loss_supervise",
                                  loss_supervise,
                                  step=step)
        if step % args.dev_step == 0:
            # fer, cer = evaluate(feature_dev, dataset_dev, args.data.dev_size, G)
            # Evaluate with and without stamp conditioning; note the first
            # `fer` is overwritten by the second call — only the with_stamp
            # CER (cer_0) of the first call is kept.
            fer, cer_0 = evaluate(feature_dev,
                                  dataset_dev,
                                  args.data.dev_size,
                                  G,
                                  beam_size=0,
                                  with_stamp=True)
            fer, cer = evaluate(feature_dev,
                                dataset_dev,
                                args.data.dev_size,
                                G,
                                beam_size=0,
                                with_stamp=False)
            with writer.as_default():
                tf.summary.scalar("performance/fer", fer, step=step)
                tf.summary.scalar("performance/cer_0", cer_0, step=step)
                tf.summary.scalar("performance/cer", cer, step=step)
        if step % args.decode_step == 0:
            monitor(dataset_dev[0], G)
        if step % args.save_step == 0:
            save_path = ckpt_manager.save(step)
            print('save model {}'.format(save_path))

        step += 1

    print('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
def train_gan():
    """Semi-supervised GAN training with a multi-output ASR model (TF1).

    Like train() above but uses TFDataReader multi-batches, a dual-encoder
    generator (encoder + encoder2), and a phone-aware dev dataloader.
    Each step runs 3 discriminator updates followed by one combined
    CTC/CE generator update; progress is measured in epochs of the
    supervised set. Relies on module-level `args` and helpers.
    """
    # NOTE(review): "form" is a typo for "from" in this log message; kept
    # byte-identical here (doc-only edit).
    print('reading data form ', args.dirs.train.tfdata)
    dataReader_train = TFDataReader(args.dirs.train.tfdata, args=args)
    batch_train = dataReader_train.fentch_multi_batch_bucket()
    dataReader_untrain = TFDataReader(args.dirs.untrain.tfdata, args=args)
    batch_untrain = dataReader_untrain.fentch_multi_batch(args.batch_size)

    args.dirs.untrain.tfdata = Path(args.dirs.untrain.tfdata)
    args.data.untrain_size = TFDataReader.read_tfdata_info(
        args.dirs.untrain.tfdata)['size_dataset']

    # Unpaired text stream feeding the discriminator's "real" side.
    dataset_text = TextDataSet(list_files=[args.dirs.text.data],
                               args=args,
                               _shuffle=True)
    tfdata_train = tf.data.Dataset.from_generator(dataset_text, (tf.int32),
                                                  (tf.TensorShape([None])))
    iter_text = tfdata_train.cache().repeat().shuffle(1000).\
        padded_batch(args.text_batch_size, ([args.max_label_len])).prefetch(buffer_size=5).\
        make_one_shot_iterator().get_next()

    dataReader_dev = TFDataReader(args.dirs.dev.tfdata,
                                  args=args,
                                  _shuffle=False,
                                  transform=True)
    dataloader_dev = ASR_Multi_DataLoader(args.dataset_dev,
                                          args,
                                          dataReader_dev.feat,
                                          dataReader_dev.phone,
                                          dataReader_dev.label,
                                          batch_size=args.batch_size,
                                          num_loops=1)

    tensor_global_step = tf.train.get_or_create_global_step()
    tensor_global_step0 = tf.Variable(0, dtype=tf.int32, trainable=False)
    tensor_global_step1 = tf.Variable(0, dtype=tf.int32, trainable=False)

    # Training and inference graphs share variables; this variant has a
    # second encoder (encoder2).
    G = args.Model(tensor_global_step,
                   encoder=args.model.encoder.type,
                   encoder2=args.model.encoder2.type,
                   decoder=args.model.decoder.type,
                   batch=batch_train,
                   training=True,
                   args=args)
    G_infer = args.Model(tensor_global_step,
                         encoder=args.model.encoder.type,
                         encoder2=args.model.encoder2.type,
                         decoder=args.model.decoder.type,
                         training=False,
                         args=args)
    vars_ASR = G.trainable_variables()
    # vars_G_ocd = G.trainable_variables('Ectc_Docd/ocd_decoder')

    D = args.Model_D(tensor_global_step1,
                     training=True,
                     name='discriminator',
                     args=args)
    gan = args.GAN([tensor_global_step0, tensor_global_step1],
                   G,
                   D,
                   batch=batch_train,
                   unbatch=batch_untrain,
                   name='GAN',
                   args=args)

    size_variables()
    start_time = datetime.now()
    saver_ASR = tf.train.Saver(vars_ASR, max_to_keep=10)
    saver = tf.train.Saver(max_to_keep=15)
    summary = Summary(str(args.dir_log))

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        dataloader_dev.sess = sess
        if args.dirs.checkpoint_G:
            saver_ASR.restore(sess, args.dirs.checkpoint_G)

        batch_time = time()
        num_processed = 0
        num_processed_unbatch = 0
        progress = 0
        while progress < args.num_epochs:
            # semi_supervise
            global_step, lr_G, lr_D = sess.run([
                tensor_global_step0, gan.learning_rate_G, gan.learning_rate_D
            ])

            # Three discriminator updates per generator update.
            for _ in range(3):
                text = sess.run(iter_text)
                text_lens = get_batch_length(text)
                shape_text = text.shape
                loss_D, loss_D_res, loss_D_text, loss_gp, _ = sess.run(
                    gan.list_train_D,
                    feed_dict={
                        gan.list_pl[0]: text,
                        gan.list_pl[1]: text_lens
                    })
            # loss_D=loss_D_res=loss_D_text=loss_gp=0
            (loss_G, ctc_loss, ce_loss, _), (shape_batch, shape_unbatch) = \
                sess.run([gan.list_train_G, gan.list_feature_shape])

            num_processed += shape_batch[0]
            # num_processed_unbatch += shape_unbatch[0]
            used_time = time() - batch_time
            batch_time = time()
            progress = num_processed / args.data.train_size
            # NOTE(review): num_processed_unbatch is never incremented (line
            # above is commented out), so progress_unbatch stays 0.
            progress_unbatch = num_processed_unbatch / args.data.untrain_size

            if global_step % 40 == 0:
                print('ctc|ce loss: {:.2f}|{:.2f}, loss res|real|gp: {:.2f}|{:.2f}|{:.2f}\t{}|{}\tlr:{:.1e}|{:.1e} {:.2f}s {:.3f}% step: {}'.format(
                    np.mean(ctc_loss), np.mean(ce_loss), loss_D_res, loss_D_text, loss_gp, shape_batch, \
                    shape_unbatch, lr_G, lr_D, used_time, progress*100, global_step))
                summary.summary_scalar('ctc_loss', np.mean(ctc_loss),
                                       global_step)
                summary.summary_scalar('ce_loss', np.mean(ce_loss),
                                       global_step)

            if global_step % args.save_step == args.save_step - 1:
                # Only the ASR (generator) variables are saved here.
                saver_ASR.save(get_session(sess),
                               str(args.dir_checkpoint / 'model_G'),
                               global_step=global_step,
                               write_meta_graph=True)
                print(
                    'saved G in',
                    str(args.dir_checkpoint) + '/model_G-' + str(global_step))
                # saver_G_en.save(get_session(sess), str(args.dir_checkpoint/'model_G_en'), global_step=global_step, write_meta_graph=True)
                # print('saved model in', str(args.dir_checkpoint)+'/model_G_en-'+str(global_step))

            if global_step % args.dev_step == args.dev_step - 1:
                # if True:
                per, cer, wer = dev(step=global_step,
                                    dataloader=dataloader_dev,
                                    model=G_infer,
                                    sess=sess,
                                    unit=args.data.unit,
                                    args=args)
                summary.summary_scalar('dev_per', per, global_step)
                summary.summary_scalar('dev_cer', cer, global_step)
                summary.summary_scalar('dev_wer', wer, global_step)

            if global_step % args.decode_step == args.decode_step - 1:
                # if True:
                decode_test(step=global_step,
                            sample=args.dataset_test.uttid2sample(
                                args.sample_uttid),
                            model=G_infer,
                            sess=sess,
                            unit=args.data.unit,
                            args=args)

    logging.info('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
def train():
    """GAN training of a phone classifier against aligned wav data (TF2).

    Builds train/dev tfdata readers (feature, label, align), one fixed
    supervised batch of args.num_supervised examples, and a text stream; then
    loops forever alternating D updates (args.opti.D_G_rate per step), one
    adversarial G update, and one supervised G update on the fixed batch.
    Relies on module-level `args` and helpers (train_D, train_G,
    train_G_supervised, evaluation, decode, ...).
    """
    dataset_dev = ASR_align_DataSet(
        file=[args.dirs.dev.data],
        args=args,
        _shuffle=False,
        transform=True)
    with tf.device("/cpu:0"):
        # wav data
        tfdata_train = TFData(dataset=None,
                              dataAttr=['feature', 'label', 'align'],
                              dir_save=args.dirs.train.tfdata,
                              args=args).read(_shuffle=False)
        tfdata_dev = TFData(dataset=None,
                            dataAttr=['feature', 'label', 'align'],
                            dir_save=args.dirs.dev.tfdata,
                            args=args).read(_shuffle=False)

        # One fixed supervised batch, reused every step.
        x_0, y_0, _ = next(iter(tfdata_train.take(args.num_supervised).map(lambda x, y, z: (x, y, z[:args.max_seq_len])).\
            padded_batch(args.num_supervised, ([None, args.dim_input], [None], [None]))))
        iter_train = iter(tfdata_train.cache().repeat().shuffle(3000).map(lambda x, y, z: (x, y, z[:args.max_seq_len])).\
            padded_batch(args.batch_size, ([None, args.dim_input], [None], [args.max_seq_len])).prefetch(buffer_size=3))
        tfdata_dev = tfdata_dev.padded_batch(
            args.batch_size, ([None, args.dim_input], [None], [None]))

        # text data
        dataset_text = TextDataSet(
            list_files=[args.dirs.lm.data],
            args=args,
            _shuffle=True)
        tfdata_train_text = tf.data.Dataset.from_generator(
            dataset_text, (tf.int32), (tf.TensorShape([None])))
        iter_text = iter(tfdata_train_text.cache().repeat().shuffle(100).map(
            lambda x: x[:args.max_seq_len]).padded_batch(
                args.batch_size,
                ([args.max_seq_len])).prefetch(buffer_size=5))

    # create model parameters
    G = PhoneClassifier(args)
    D = PhoneDiscriminator2(args)
    G.summary()
    D.summary()

    optimizer_G = tf.keras.optimizers.Adam(args.opti.G.lr,
                                           beta_1=0.5,
                                           beta_2=0.9)
    optimizer_D = tf.keras.optimizers.Adam(args.opti.D.lr,
                                           beta_1=0.5,
                                           beta_2=0.9)
    # FIX: removed an unused third Adam optimizer (`optimizer`) that was
    # created with G's hyperparameters but never referenced.

    writer = tf.summary.create_file_writer(str(args.dir_log))
    ckpt = tf.train.Checkpoint(G=G, optimizer_G=optimizer_G)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              args.dir_checkpoint,
                                              max_to_keep=5)
    step = 0

    # if a checkpoint exists, restore the latest checkpoint.
    if args.dirs.checkpoint:
        _ckpt_manager = tf.train.CheckpointManager(ckpt,
                                                   args.dirs.checkpoint,
                                                   max_to_keep=1)
        ckpt.restore(_ckpt_manager.latest_checkpoint)
        print('checkpoint {} restored!!'.format(
            _ckpt_manager.latest_checkpoint))
        step = int(_ckpt_manager.latest_checkpoint.split('-')[-1])

    start_time = datetime.now()
    num_processed = 0
    progress = 0

    while step < 99999999:
        start = time()

        # D sub-phase: several discriminator updates per G update.
        for _ in range(args.opti.D_G_rate):
            x, _, aligns = next(iter_train)
            text = next(iter_text)
            P_Real = tf.one_hot(text, args.dim_output)
            cost_D, gp = train_D(x, aligns, P_Real, text > 0, G, D,
                                 optimizer_D, args.lambda_gp)

        # G sub-phase: adversarial update plus supervised update on the
        # fixed batch.
        x, _, aligns = next(iter_train)
        cost_G, fs = train_G(x, aligns, G, D, optimizer_G, args.lambda_fs)
        loss_supervise = train_G_supervised(x_0, y_0, G, optimizer_G,
                                            args.dim_output)

        num_processed += len(x)
        # BUG FIX: `progress` was initialized to 0 but never recomputed, so
        # the progress column always printed 0.000%; every sibling trainer in
        # this file updates it from num_processed.
        progress = num_processed / args.data.train_size
        if step % 10 == 0:
            print(
                'cost_G: {:.3f}|{:.3f}\tcost_D: {:.3f}|{:.3f}\tloss_supervise: {:.3f}\tbatch: {}|{}\tused: {:.3f}\t {:.3f}% iter: {}'
                .format(cost_G, fs, cost_D, gp, loss_supervise, x.shape,
                        text.shape,
                        time() - start, progress * 100.0, step))
            with writer.as_default():
                tf.summary.scalar("costs/cost_G", cost_G, step=step)
                tf.summary.scalar("costs/cost_D", cost_D, step=step)
                tf.summary.scalar("costs/gp", gp, step=step)
                tf.summary.scalar("costs/fs", fs, step=step)
                tf.summary.scalar("costs/loss_supervise",
                                  loss_supervise,
                                  step=step)
        if step % args.dev_step == 0:
            fer, cer = evaluation(tfdata_dev, args.data.dev_size, G)
            with writer.as_default():
                tf.summary.scalar("performance/fer", fer, step=step)
                tf.summary.scalar("performance/cer", cer, step=step)
        if step % args.decode_step == 0:
            decode(dataset_dev[0], G)
        if step % args.save_step == 0:
            save_path = ckpt_manager.save(step)
            print('save model {}'.format(save_path))

        step += 1

    print('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
def Train():
    """Joint supervised CTC + unsupervised GAN training of an encoder/decoder
    ASR model against a character language model discriminator (TF2).

    Each step: one supervised CTC update on a bucketed train batch, then one
    generator and one discriminator update on unsupervised audio paired with
    unpaired text. Progress is measured in epochs of the unsupervised set.
    Relies on module-level `args` and helpers (train_CTC_supervised, train_G,
    train_D, evaluate, monitor, ...).
    """
    args.data.untrain_size = TFData.read_tfdata_info(
        args.dirs.untrain.tfdata)['size_dataset']

    with tf.device("/cpu:0"):
        dataset_train = ASR_align_ArkDataSet(scp_file=args.dirs.train.scp,
                                             trans_file=args.dirs.train.trans,
                                             align_file=None,
                                             feat_len_file=None,
                                             args=args,
                                             _shuffle=False,
                                             transform=False)
        dataset_untrain = ASR_align_ArkDataSet(scp_file=args.dirs.untrain.scp,
                                               trans_file=None,
                                               align_file=None,
                                               feat_len_file=None,
                                               args=args,
                                               _shuffle=False,
                                               transform=False)
        dataset_dev = ASR_align_ArkDataSet(scp_file=args.dirs.dev.scp,
                                           trans_file=args.dirs.dev.trans,
                                           align_file=None,
                                           feat_len_file=None,
                                           args=args,
                                           _shuffle=False,
                                           transform=False)
        # wav data
        feature_train = TFData(dataset=dataset_train,
                               dir_save=args.dirs.train.tfdata,
                               args=args).read()
        feature_unsupervise = TFData(dataset=dataset_untrain,
                                     dir_save=args.dirs.untrain.tfdata,
                                     args=args).read()
        feature_dev = TFData(dataset=dataset_dev,
                             dir_save=args.dirs.dev.tfdata,
                             args=args).read()
        # Bucket supervised batches by utterance length to reduce padding.
        bucket = tf.data.experimental.bucket_by_sequence_length(
            element_length_func=lambda uttid, x: tf.shape(x)[0],
            bucket_boundaries=args.list_bucket_boundaries,
            bucket_batch_sizes=args.list_batch_size,
            padded_shapes=((), [None, args.dim_input]))
        iter_feature_train = iter(
            feature_train.repeat().shuffle(100).apply(bucket).prefetch(
                buffer_size=5))
        # iter_feature_unsupervise = iter(feature_unsupervise.repeat().shuffle(100).apply(bucket).prefetch(buffer_size=5))
        # iter_feature_train = iter(feature_train.repeat().shuffle(100).padded_batch(args.batch_size,
        #     ((), [None, args.dim_input])).prefetch(buffer_size=5))
        iter_feature_unsupervise = iter(
            feature_unsupervise.repeat().shuffle(100).padded_batch(
                args.batch_size,
                ((), [None, args.dim_input])).prefetch(buffer_size=5))
        # feature_dev = feature_dev.apply(bucket).prefetch(buffer_size=5)
        feature_dev = feature_dev.padded_batch(args.batch_size,
                                               ((), [None, args.dim_input]))

        # Unpaired text stream feeding the discriminator's "real" side.
        dataset_text = TextDataSet(list_files=[args.dirs.lm.data],
                                   args=args,
                                   _shuffle=True)
        tfdata_train = tf.data.Dataset.from_generator(dataset_text, (tf.int32),
                                                      (tf.TensorShape([None])))
        iter_text = iter(tfdata_train.cache().repeat().shuffle(1000).map(
            lambda x: x[:args.model.D.max_label_len]).padded_batch(
                args.text_batch_size,
                ([args.model.D.max_label_len])).prefetch(buffer_size=5))

    # create model parameters
    encoder = Encoder(args)
    decoder = Decoder(args)
    D = CLM(args)
    encoder.summary()
    decoder.summary()
    D.summary()
    optimizer = tf.keras.optimizers.Adam(0.0001, beta_1=0.5, beta_2=0.9)
    optimizer_D = tf.keras.optimizers.Adam(0.0001, beta_1=0.5, beta_2=0.9)

    writer = tf.summary.create_file_writer(str(args.dir_log))
    # Only encoder/decoder are checkpointed; the CLM discriminator is not.
    ckpt_G = tf.train.Checkpoint(encoder=encoder, decoder=decoder)
    ckpt_manager = tf.train.CheckpointManager(ckpt_G,
                                              args.dir_checkpoint,
                                              max_to_keep=20)
    step = 0

    # NOTE(review): unlike the sibling trainers, `step` is NOT recovered from
    # the checkpoint name here — training resumes from step 0. Confirm intended.
    if args.dirs.checkpoint_G:
        _ckpt_manager = tf.train.CheckpointManager(ckpt_G,
                                                   args.dirs.checkpoint_G,
                                                   max_to_keep=1)
        ckpt_G.restore(_ckpt_manager.latest_checkpoint)
        print('checkpoint_G {} restored!!'.format(
            _ckpt_manager.latest_checkpoint))
        # cer = evaluate(feature_dev, dataset_dev, args.data.dev_size, encoder, decoder)
        # with writer.as_default():
        #     tf.summary.scalar("performance/cer", cer, step=step)

    start_time = datetime.now()
    num_processed = 0

    while step < 99999999:
        start = time()

        # supervised training
        uttids, x = next(iter_feature_train)
        trans = dataset_train.get_attrs('trans', uttids.numpy())
        loss_supervise = train_CTC_supervised(x, trans, encoder, decoder,
                                              optimizer)

        # unsupervised (adversarial) training
        text = next(iter_text)
        _, un_x = next(iter_feature_unsupervise)
        # loss_G = train_G(un_x, encoder, decoder, D, optimizer, args.model.D.max_label_len)
        loss_G = train_G(un_x, encoder, decoder, D, optimizer,
                         args.model.D.max_label_len)
        loss_D = train_D(un_x, text, encoder, decoder, D, optimizer_D,
                         args.lambda_gp, args.model.D.max_label_len)

        num_processed += len(un_x)
        progress = num_processed / args.data.untrain_size
        if step % 10 == 0:
            print(
                'loss_supervise: {:.3f}\tloss_G: {:.3f}\tloss_D: {:.3f}\tbatch: {}\tused: {:.3f}\t {:.3f}% step: {}'
                .format(loss_supervise, loss_G, loss_D, un_x.shape,
                        time() - start, progress * 100, step))
            with writer.as_default():
                tf.summary.scalar("costs/loss_supervise",
                                  loss_supervise,
                                  step=step)
        if step % args.dev_step == args.dev_step - 1:
            cer = evaluate(feature_dev, dataset_dev, args.data.dev_size,
                           encoder, decoder)
            with writer.as_default():
                tf.summary.scalar("performance/cer", cer, step=step)
        if step % args.decode_step == 0:
            monitor(dataset_dev[0], encoder, decoder)
        if step % args.save_step == 0:
            save_path = ckpt_manager.save(step)
            print('save model {}'.format(save_path))

        step += 1

    print('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
def train(Model):
    """Train a text model (language model) with a warmup/decay Adam schedule.

    Args:
        Model: model class; instantiated as Model(args) — must provide
            summary() and be callable by the module-level train_step().

    Streams padded text batches, logs loss/lr/timing every 10 steps, and runs
    evaluation/checkpointing on args.dev_step / args.save_step boundaries.
    Relies on module-level `args` and helpers (TextDataSet,
    warmup_exponential_decay, train_step, evaluation).
    """
    # create dataset and dataloader
    dataset_train = TextDataSet(list_files=[args.dirs.train.data],
                                args=args,
                                _shuffle=True)
    dataset_dev = TextDataSet(list_files=[args.dirs.dev.data],
                              args=args,
                              _shuffle=False)
    args.data.train_size = len(dataset_train)
    args.data.dev_size = len(dataset_dev)
    tfdata_train = tf.data.Dataset.from_generator(dataset_train, (tf.int32),
                                                  (tf.TensorShape([None])))
    tfdata_dev = tf.data.Dataset.from_generator(dataset_dev, (tf.int32),
                                                (tf.TensorShape([None])))
    tfdata_train = tfdata_train.cache().repeat().shuffle(500).padded_batch(
        args.batch_size, ([None])).prefetch(buffer_size=5)
    tfdata_dev = tfdata_dev.padded_batch(args.batch_size, ([None]))

    # build optimizer (Transformer-style warmup then exponential decay)
    warmup = warmup_exponential_decay(warmup_steps=args.opti.warmup_steps,
                                      peak=args.opti.peak,
                                      decay_steps=args.opti.decay_steps)
    optimizer = tf.keras.optimizers.Adam(warmup,
                                         beta_1=0.9,
                                         beta_2=0.98,
                                         epsilon=1e-9)

    # create model parameters
    model = Model(args)
    model.summary()

    # save & reload
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              args.dir_checkpoint,
                                              max_to_keep=20)
    if args.dirs.restore:
        latest_checkpoint = tf.train.CheckpointManager(
            ckpt, args.dirs.restore, max_to_keep=1).latest_checkpoint
        ckpt.restore(latest_checkpoint)
        print('{} restored!!'.format(latest_checkpoint))

    start_time = datetime.now()
    # BUG FIX: get_data_time was initialized to 0, so the first logged fetch
    # time was `time()` itself (seconds since epoch). Start the timer now.
    get_data_time = time()
    num_processed = 0
    progress = 0
    for global_step, batch in enumerate(tfdata_train):
        x, y = batch
        run_model_time = time()
        loss = train_step(x, y, model, optimizer)

        num_processed += len(x)
        # Fetch time = moment the model step started minus the moment the
        # previous iteration finished.
        get_data_time = run_model_time - get_data_time
        run_model_time = time() - run_model_time

        progress = num_processed / args.data.train_size
        if global_step % 10 == 0:
            print(
                'loss: {:.5f}\t batch: {} lr:{:.6f} time: {:.2f}|{:.2f} s {:.3f}% step: {}'
                .format(loss, x.shape,
                        warmup(global_step * 1.0).numpy(), get_data_time,
                        run_model_time, progress * 100.0, global_step))
        # Reset the fetch timer every iteration so the measurement always
        # covers a single batch. (Placement chosen for sensible per-batch
        # values; original collapsed source was ambiguous here.)
        get_data_time = time()

        if global_step % args.dev_step == 0:
            evaluation(tfdata_dev, model)
        # if global_step % args.decode_step == 0:
        #     decode(model)
        if global_step % args.save_step == 0:
            ckpt_manager.save()

    print('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
def train():
    """Semi-supervised text GAN training loop (TF1 graph/session style).

    Builds three text pipelines (unsupervised train, small supervised set,
    dev), a Generator pair (train/infer), a discriminator, and a
    ``Conditional_GAN`` wrapper; then alternates generator updates (1 per
    step) with discriminator updates (3 per step) inside a
    ``MonitoredTrainingSession``, feeding int-sequence batches converted to
    noisy feature vectors via ``int2vector``. Periodically logs, saves
    checkpoints and runs dev/decode.

    All configuration comes from the module-level ``args`` namespace.
    """
    # unsupervised text stream: infinite, shuffled, padded to max_label_len
    dataset_text = TextDataSet(list_files=[args.dirs.text.data],
                               args=args,
                               _shuffle=True)
    tfdata_train = tf.data.Dataset.from_generator(
        dataset_text, (tf.int32), (tf.TensorShape([None])))
    iter_text = tfdata_train.cache().repeat().shuffle(10000).\
        padded_batch(args.text_batch_size, ([args.max_label_len])).prefetch(buffer_size=5).\
        make_one_shot_iterator().get_next()

    # small supervised stream: one batch of size args.num_supervised is drawn once below
    dataset_text_supervise = TextDataSet(list_files=[args.dirs.text.supervise],
                                         args=args,
                                         _shuffle=True)
    tfdata_supervise = tf.data.Dataset.from_generator(
        dataset_text_supervise, (tf.int32), (tf.TensorShape([None])))
    iter_supervise = tfdata_supervise.cache().repeat().shuffle(100).\
        padded_batch(args.num_supervised, ([args.max_label_len])).prefetch(buffer_size=5).\
        make_one_shot_iterator().get_next()

    # dev stream: finite, re-initializable so dev() can run it repeatedly
    dataset_text_dev = TextDataSet(list_files=[args.dirs.text.dev],
                                   args=args,
                                   _shuffle=False)
    tfdata_dev = tf.data.Dataset.from_generator(
        dataset_text_dev, (tf.int32), (tf.TensorShape([None])))
    tfdata_dev = tfdata_dev.cache().\
        padded_batch(args.text_batch_size, ([args.max_label_len])).prefetch(buffer_size=5).\
        make_initializable_iterator()
    iter_text_dev = tfdata_dev.get_next()

    # Separate step counters for G and D optimizers.
    # NOTE(review): nothing visible here increments tensor_global_step0 —
    # presumably the optimizer ops inside gan.list_train_G do; confirm,
    # otherwise the while-loop guard below never advances.
    tensor_global_step = tf.Variable(0, dtype=tf.int32, trainable=False)
    tensor_global_step0 = tf.Variable(0, dtype=tf.int32, trainable=False)
    tensor_global_step1 = tf.Variable(0, dtype=tf.int32, trainable=False)

    G = Generator(tensor_global_step,
                  hidden=args.model.hidden_size,
                  num_blocks=args.model.num_blocks,
                  training=True,
                  args=args)
    # inference twin sharing variables with G (training=False)
    G_infer = Generator(tensor_global_step,
                        hidden=args.model.hidden_size,
                        num_blocks=args.model.num_blocks,
                        training=False,
                        args=args)
    vars_G = G.trainable_variables

    D = args.Model_D(tensor_global_step1,
                     training=True,
                     name='discriminator',
                     args=args)
    gan = Conditional_GAN([tensor_global_step0, tensor_global_step1], G, D,
                          batch=None, unbatch=None, name='text_gan', args=args)

    size_variables()
    start_time = datetime.now()
    # saver_G restores only generator weights (warm start); saver checkpoints everything
    saver_G = tf.train.Saver(vars_G, max_to_keep=1)
    saver = tf.train.Saver(max_to_keep=15)
    summary = Summary(str(args.dir_log))

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False

    with tf.train.MonitoredTrainingSession(config=config) as sess:
        if args.dirs.checkpoint_G:
            saver_G.restore(sess, args.dirs.checkpoint_G)
            # NOTE(review): 'revocer' is a typo for 'recover' in this log
            # message (left unchanged here — runtime string).
            print('revocer G from', args.dirs.checkpoint_G)

        # Draw the supervised batch ONCE; it is reused as the conditioning
        # input for every generator update below.
        # np.random.seed(0)
        text_supervise = sess.run(iter_supervise)
        text_len_supervise = get_batch_length(text_supervise)
        feature_supervise, feature_len_supervise = int2vector(
            text_supervise, text_len_supervise,
            hidden_size=args.model.dim_input, uprate=args.uprate)
        # additive Gaussian noise scaled by 1/args.noise
        feature_supervise += np.random.randn(*feature_supervise.shape)/args.noise

        batch_time = time()
        global_step = 0
        # for _ in range(100):
        #     np.random.seed(1)
        #     text_G = sess.run(iter_text)
        #     text_lens_G = get_batch_length(text_G)
        #     feature_text, text_lens_G = int2vector(text_G, text_lens_G, hidden_size=args.model.dim_input, uprate=args.uprate)
        #     feature_text += np.random.randn(*feature_text.shape)/args.noise
        #     loss_G, loss_G_supervise, _ = sess.run(gan.list_train_G,
        #                                            feed_dict={gan.list_G_pl[0]:feature_text,
        #                                                       gan.list_G_pl[1]:text_lens_G,
        #                                                       gan.list_G_pl[2]:feature_supervise,
        #                                                       gan.list_G_pl[3]:feature_len_supervise,
        #                                                       gan.list_G_pl[4]:text_supervise,
        #                                                       gan.list_G_pl[5]:text_len_supervise})
        # saver.save(get_session(sess), str(args.dir_checkpoint/'model'), global_step=0, write_meta_graph=True)

        while global_step < 99999999:
            # global_step, lr_G, lr_D = sess.run([tensor_global_step0, gan.learning_rate_G, gan.learning_rate_D])
            global_step, lr_G = sess.run([tensor_global_step0, G.learning_rate])
            # text_supervise = sess.run(iter_supervise)
            # text_len_supervise = get_batch_length(text_supervise)
            # feature_supervise, feature_len_supervise = int2vector(text_supervise, text_len_supervise, hidden_size=args.model.dim_input, uprate=args.uprate)
            # feature_supervise += np.random.randn(*feature_supervise.shape)/args.noise

            # supervise
            # for _ in range(1):
            #     loss_G_supervise, _ = sess.run(G.run_list,
            #                                    feed_dict={G.list_pl[0]:feature_text_supervise,
            #                                               G.list_pl[1]:text_len_supervise,
            #                                               G.list_pl[2]:text_supervise,
            #                                               G.list_pl[3]:text_len_supervise})

            # generator input: fresh unsupervised batch -> noisy feature vectors
            text_G = sess.run(iter_text)
            text_lens_G = get_batch_length(text_G)
            feature_text, text_lens_G = int2vector(
                text_G, text_lens_G,
                hidden_size=args.model.dim_input, uprate=args.uprate)
            feature_text += np.random.randn(*feature_text.shape)/args.noise

            # one generator update per outer step
            loss_G, loss_G_supervise, _ = sess.run(gan.list_train_G,
                                                   feed_dict={gan.list_G_pl[0]:feature_text,
                                                              gan.list_G_pl[1]:text_lens_G,
                                                              gan.list_G_pl[2]:feature_supervise,
                                                              gan.list_G_pl[3]:feature_len_supervise,
                                                              gan.list_G_pl[4]:text_supervise,
                                                              gan.list_G_pl[5]:text_len_supervise})
            # loss_G = loss_G_supervise = 0

            # discriminator input: 3 D updates per G update, each drawing two
            # fresh batches from iter_text (one fake-source, one real-text) —
            # the ORDER of these sess.run calls determines batch consumption.
            for _ in range(3):
                # np.random.seed(2)
                text_G = sess.run(iter_text)
                text_lens_G = get_batch_length(text_G)
                feature_G, feature_lens_G = int2vector(
                    text_G, text_lens_G,
                    hidden_size=args.model.dim_input, uprate=args.uprate)
                feature_G += np.random.randn(*feature_G.shape)/args.noise

                text_D = sess.run(iter_text)
                text_lens_D = get_batch_length(text_D)
                shape_text = text_D.shape
                loss_D, loss_D_res, loss_D_text, loss_gp, _ = sess.run(gan.list_train_D,
                                                                       feed_dict={gan.list_D_pl[0]:text_D,
                                                                                  gan.list_D_pl[1]:text_lens_D,
                                                                                  gan.list_G_pl[0]:feature_G,
                                                                                  gan.list_G_pl[1]:feature_lens_G})
                # loss_D_res = loss_D_text = loss_gp = 0
                # shape_text = [0,0,0]

            # loss_D_res = - loss_G
            # loss_G = loss_G_supervise = 0.0
            # loss_D = loss_D_text = loss_gp = 0.0

            # train
            # if global_step % 5 == 0:
            #     for _ in range(2):
            #         loss_supervise, shape_batch, _, _ = sess.run(G.list_run)
            #     loss_G_supervise = 0

            used_time = time()-batch_time
            batch_time = time()

            if global_step % 10 == 0:
                # print('loss_G: {:.2f} loss_G_supervise: {:.2f} loss_D_res: {:.2f} loss_D_text: {:.2f} step: {}'.format(
                #     loss_G, loss_G_supervise, loss_D_res, loss_D_text, global_step))
                print('loss_G_supervise: {:.2f} loss res|real|gp: {:.2f}|{:.2f}|{:.2f}\tbatch: {}\tlr:{:.1e} {:.2f}s step: {}'.format(
                    loss_G_supervise, loss_D_res, loss_D_text, loss_gp,
                    shape_text, lr_G, used_time, global_step))
                # summary.summary_scalar('loss_G', loss_G, global_step)
                # summary.summary_scalar('loss_D', loss_D, global_step)
                # summary.summary_scalar('lr_G', lr_G, global_step)
                # summary.summary_scalar('lr_D', lr_D, global_step)

            if global_step % args.save_step == args.save_step - 1:
                saver.save(get_session(sess),
                           str(args.dir_checkpoint/'model'),
                           global_step=global_step,
                           write_meta_graph=True)
                print('saved model as', str(args.dir_checkpoint) + '/model-' + str(global_step))

            # if global_step % args.dev_step == args.dev_step - 1:
            if global_step % args.dev_step == 1:
                text_G_dev = dev(iter_text_dev, tfdata_dev, dataset_text_dev,
                                 sess, G_infer)
            # if global_step % args.decode_step == args.decode_step - 1:
            # NOTE(review): text_G_dev is only bound inside the dev branch
            # above — if this decode condition fires on a step before the dev
            # branch has ever run, this raises NameError. Verify the
            # dev_step/decode_step relationship guarantees dev runs first.
            if global_step % args.decode_step == args.decode_step - 1:
                decode(text_G_dev, tfdata_dev, sess, G_infer)

    logging.info('training duration: {:.2f}h'.format(
        (datetime.now()-start_time).total_seconds()/3600))