def main(argv):
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    run_config = tf.estimator.RunConfig().replace(session_config=session_config)
    train_feeder = DataFeeder(['train_random_video.txt'] + [
        'train_random_video_04{}.txt'.format(n)
        for n in [13, 15, 16, 17, 18, 19, 20, 21]
    ])
    test_feeder = DataFeeder(['train_HRS_video.txt'], is_training=False)
    classifier = tf.estimator.Estimator(model_fn=my_model,
                                        model_dir='log8_moredata_cos',
                                        config=run_config,
                                        params=train_feeder)
    train_spec = tf.estimator.TrainSpec(
        input_fn=lambda: [train_feeder.tensor_data_generator(128, 100), None])
    test_spec = tf.estimator.EvalSpec(
        input_fn=lambda: [test_feeder.tensor_data_generator(128, 100), None],
        throttle_secs=600)
    tf.estimator.train_and_evaluate(classifier, train_spec, test_spec)
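# `my_model` is not defined in this snippet. The sketch below is only an
# assumption illustrating the tf.estimator model_fn contract it must satisfy
# (features/labels/mode/params in, EstimatorSpec out); the feature keys
# 'content' and 'label' are hypothetical, chosen because the input_fn above
# passes labels=None, so any loss must come from the features themselves.
def my_model_sketch(features, labels, mode, params):
    # `params` would be the DataFeeder passed as `params=train_feeder` above.
    logits = tf.layers.dense(features['content'], 2)  # hypothetical feature key
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions={'logits': logits})
    loss = tf.losses.sparse_softmax_cross_entropy(labels=features['label'],
                                                  logits=logits)
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = tf.train.AdamOptimizer().minimize(
            loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
    return tf.estimator.EstimatorSpec(mode, loss=loss)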
def train(log_dir, args):
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    input_path = './jenie_Processed/amused/'
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)
    log('Using model: %s' % args.model)
    log(hparams_debug_string())

    # Set up DataFeeder:
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = DataFeeder(coord, input_path, hparams)

    # Set up model:
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model(args.model, hparams)
        model.initialize(feeder.inputs, feeder.input_lengths,
                         feeder.mel_targets, feeder.linear_targets)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    # Bookkeeping:
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

    # Train!
    with tf.Session() as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())
            restore_path = './Trained Weights/model.ckpt-lj'
            print(restore_path)
            saver.restore(sess, restore_path)
            log('[INFO] Resuming from checkpoint: %s' % restore_path)
            feeder.start_in_session(sess)

            while not coord.should_stop():
                start_time = time.time()
                step, loss, opt = sess.run([global_step, model.loss, model.optimize])
                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                    step, time_window.average, loss, loss_window.average)
                log(message)

                if loss > 100 or math.isnan(loss):
                    log('Loss exploded to %.05f at step %d!' % (loss, step))
                    raise Exception('Loss exploded')

                if step % args.summary_interval == 0:
                    log('Writing summary at step: %d' % step)
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)
                    log('Saving audio and alignment...')
                    input_seq, spectrogram, alignment = sess.run([
                        model.inputs[0], model.linear_outputs[0], model.alignments[0]
                    ])
                    waveform = audio.inv_spectrogram(spectrogram.T)
                    audio.save_wav(waveform,
                                   os.path.join(log_dir, 'step-%d-audio.wav' % step))
                    plot.plot_alignment(
                        alignment,
                        os.path.join(log_dir, 'step-%d-align.png' % step),
                        info='%s, %s, step=%d, loss=%.5f' %
                             (args.model, time_string(), step, loss))
                    log('Input: %s' % sequence_to_text(input_seq))

        except Exception as e:
            log('Exiting due to exception: %s' % e)
            traceback.print_exc()
            coord.request_stop(e)
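# `ValueWindow` is used for the running averages throughout these training loops
# but is not defined in the snippets. A minimal sketch consistent with how it is
# called (append() plus an `average` property over the last N values); the exact
# implementation in the original repos may differ:
class ValueWindow:
    def __init__(self, window_size=100):
        self._window_size = window_size
        self._values = []

    def append(self, x):
        # Keep only the most recent `window_size` values
        self._values = self._values[-(self._window_size - 1):] + [x]

    @property
    def average(self):
        return sum(self._values) / max(1, len(self._values))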
def train(log_dir, input_path, checkpoint_path, is_restore):
    # Log the run info
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)
    log(hparams_debug_string())

    # Set up DataFeeder:
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = DataFeeder(coord, input_path, hparams)

    # Set up model:
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model('tacotron', hparams)
        model.initialize(feeder.inputs, feeder.input_lengths,
                         feeder.mel_targets, feeder.linear_targets)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    # Bookkeeping:
    step = 0
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

    # Train!
    with tf.Session() as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())
            if is_restore:
                # Restore from a checkpoint if the user requested it.
                saver.restore(sess, checkpoint_path)
                log('Resuming from checkpoint')
            else:
                log('Starting new training')
            feeder.start_in_session(sess)

            while not coord.should_stop():
                start_time = time.time()
                step, loss, opt = sess.run([global_step, model.loss, model.optimize])
                time_interval = time.time() - start_time
                message = 'Step %d, %.03f sec, loss=%.05f' % (step, time_interval, loss)
                log(message)

                if loss > 100 or math.isnan(loss):
                    log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True)
                    raise Exception('Loss exploded')

                if step % hparams.summary_interval == 0:
                    log('Writing summary at step: %d' % step)
                    summary_writer.add_summary(sess.run(stats), step)

                if step % hparams.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)
                    log('Saving audio and alignment...')
                    input_seq, spectrogram, alignment = sess.run([
                        model.inputs[0], model.linear_outputs[0], model.alignments[0]
                    ])
                    waveform = audio.inv_spectrogram(spectrogram.T)
                    audio.save_wav(waveform,
                                   os.path.join(log_dir, 'step-%d-audio.wav' % step))
                    plot.plot_alignment(
                        alignment,
                        os.path.join(log_dir, 'step-%d-align.png' % step),
                        info='%s, %s, step=%d, loss=%.5f' %
                             ('tacotron', time_string(), step, loss))
                    log('Input: %s' % sequence_to_text(input_seq))

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            coord.request_stop(e)
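# `add_stats` is called in several of these loops but not defined here. A
# plausible sketch, assuming it merges TensorBoard scalar summaries for the
# model's loss and learning rate; which tensors are actually summarized in the
# original repos is an assumption:
def add_stats_sketch(model):
    with tf.variable_scope('stats'):
        tf.summary.scalar('loss', model.loss)
        tf.summary.scalar('learning_rate', model.learning_rate)
        return tf.summary.merge_all()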
def train(logdir, args):
    # Parse checkpoint path, arguments, and hparams
    checkpoint_path = os.path.join(logdir, 'model.ckpt')
    input_path = args.data_dir
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)
    log('Using model: %s' % args.model)

    # Set up DataFeeder:
    with tf.variable_scope('datafeeder') as scope:
        hp.data_length = None
        hp.initial_learning_rate = 0.0001
        hp.batch_size = 256
        hp.prime = True
        hp.stcmd = True
        feeder = DataFeeder(args=hp)
        log('num_sentences: ' + str(len(feeder.wav_lst)))  # 283600
        hp.input_vocab_size = len(feeder.pny_vocab)
        hp.final_output_dim = len(feeder.pny_vocab)
        hp.steps_per_epoch = len(feeder.wav_lst) // hp.batch_size
        log('steps_per_epoch: ' + str(hp.steps_per_epoch))  # 17725
        log('pinyin_vocab_size: ' + str(hp.input_vocab_size))  # 1292
        hp.label_vocab_size = len(feeder.han_vocab)
        log('label_vocab_size: ' + str(hp.label_vocab_size))  # 6291

    # Set up model:
    global_step = tf.Variable(initial_value=0, name='global_step', trainable=False)
    valid_step = 0
    # valid_global_step = tf.Variable(initial_value=0, name='valid_global_step', trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model(args.model, hp)
        model.build_graph()
        model.add_loss()
        model.add_optimizer(global_step=global_step, loss=model.mean_loss)

    # Summaries:
    stats = add_stats(model=model)
    valid_stats = add_dev_stats(model)

    # Set up saver and bookkeeping:
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    acc_window = ValueWindow(100)
    valid_time_window = ValueWindow(100)
    valid_loss_window = ValueWindow(100)
    valid_acc_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=20)
    first_serving = True

    # Train!
    with tf.Session() as sess:
        log(hparams_debug_string(hp))
        try:
            # Set up writers and initializer
            summary_writer = tf.summary.FileWriter(logdir + '/train', sess.graph)
            summary_writer_dev = tf.summary.FileWriter(logdir + '/dev')
            sess.run(tf.global_variables_initializer())

            # Restore from a checkpoint if the user requested it.
            if args.restore_step:
                restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s' % restore_path)
            else:
                log('Starting new training run')

            step = 0
            for i in range(args.epochs):
                batch_data = feeder.get_lm_batch()
                log('Training epoch ' + str(i) + ':')
                for j in range(hp.steps_per_epoch):
                    input_batch, label_batch = next(batch_data)
                    feed_dict = {
                        model.x: input_batch,
                        model.y: label_batch,
                    }

                    # Run one step
                    start_time = time.time()
                    total_step, batch_loss, batch_acc, opt = sess.run(
                        [global_step, model.mean_loss, model.acc, model.optimize],
                        feed_dict=feed_dict)
                    time_window.append(time.time() - start_time)
                    step = total_step

                    # Track loss and accuracy
                    loss_window.append(batch_loss)
                    acc_window.append(batch_acc)
                    message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, acc=%.05f, avg_acc=%.05f, lr=%.07f]' % (
                        step, time_window.average, batch_loss, loss_window.average,
                        batch_acc, acc_window.average, K.get_value(model.learning_rate))
                    log(message)

                    # Check for diverged loss
                    if math.isnan(batch_loss):
                        log('Loss exploded to %.05f at step %d!' % (batch_loss, step))
                        raise Exception('Loss exploded')

                    # Write summaries
                    if step % args.summary_interval == 0:
                        log('Writing summary at step: %d' % step)
                        summary_writer.add_summary(sess.run(stats, feed_dict=feed_dict), step)

                    # Save checkpoints
                    if step % args.checkpoint_interval == 0:
                        log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
                        saver.save(sess, checkpoint_path, global_step=step)
                        log('test acc...')
                        label, final_pred_label = sess.run([model.y, model.preds],
                                                           feed_dict=feed_dict)
                        log('label.shape: ' + str(label.shape))  # (batch_size, label_length)
                        log('final_pred_label.shape: ' + str(np.asarray(final_pred_label).shape))  # (1, batch_size, decode_length <= label_length)
                        log('label: ' + str(label[0]))
                        log('final_pred_label: ' + str(np.asarray(final_pred_label)[0]))

                    # Export a serving model
                    if args.serving:  # and total_step // hp.steps_per_epoch > 5:
                        np.save('logdir/lm_pinyin_dict.npy', feeder.pny_vocab)
                        np.save('logdir/lm_hanzi_dict.npy', feeder.han_vocab)

                        # Set up serving builder and signature map
                        serve_dir = args.serving_dir + '0001'
                        if os.path.exists(serve_dir):
                            os.removedirs(serve_dir)
                        builder = tf.saved_model.builder.SavedModelBuilder(export_dir=serve_dir)
                        input_info = tf.saved_model.utils.build_tensor_info(model.x)
                        output_labels = tf.saved_model.utils.build_tensor_info(model.preds)
                        prediction_signature = (
                            tf.saved_model.signature_def_utils.build_signature_def(
                                inputs={'pinyin': input_info},
                                outputs={'hanzi': output_labels},
                                method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
                        if first_serving:
                            first_serving = False
                            builder.add_meta_graph_and_variables(
                                sess=sess,
                                tags=[tf.saved_model.tag_constants.SERVING],
                                signature_def_map={
                                    'predict_Pinyin2Hanzi': prediction_signature,
                                },
                                main_op=tf.tables_initializer(),
                                strip_default_attrs=True)
                        builder.save()
                        log('Done storing serving model')
                        raise Exception('Done storing serving model')

                    # Validation
                    # if total_step % hp.steps_per_epoch == 0 and i >= 10:
                    if total_step % hp.steps_per_epoch == 0:
                        log('validation...')
                        valid_start = time.time()
                        # Deep-copy so the dev settings don't mutate the training hparams
                        valid_hp = copy.deepcopy(hp)
                        print('feature_type: ', hp.feature_type)
                        valid_hp.data_type = 'dev'
                        valid_hp.thchs30 = True
                        valid_hp.aishell = True
                        valid_hp.prime = True
                        valid_hp.stcmd = True
                        valid_hp.shuffle = True
                        valid_hp.data_length = None
                        valid_feeder = DataFeeder(args=valid_hp)
                        # Reuse the training vocabularies so token ids stay consistent
                        valid_feeder.pny_vocab = feeder.pny_vocab
                        valid_feeder.han_vocab = feeder.han_vocab
                        # valid_feeder.am_vocab = feeder.am_vocab
                        valid_batch_data = valid_feeder.get_lm_batch()
                        log('valid_num_sentences: ' + str(len(valid_feeder.wav_lst)))  # 15219
                        valid_hp.input_vocab_size = len(valid_feeder.pny_vocab)
                        valid_hp.final_output_dim = len(valid_feeder.pny_vocab)
                        valid_hp.steps_per_epoch = len(valid_feeder.wav_lst) // valid_hp.batch_size
                        log('valid_steps_per_epoch: ' + str(valid_hp.steps_per_epoch))  # 951
                        log('valid_pinyin_vocab_size: ' + str(valid_hp.input_vocab_size))  # 1124
                        valid_hp.label_vocab_size = len(valid_feeder.han_vocab)
                        log('valid_label_vocab_size: ' + str(valid_hp.label_vocab_size))  # 3327

                        # One epoch over the dev set is enough
                        with tf.variable_scope('validation') as scope:
                            for k in range(len(valid_feeder.wav_lst) // valid_hp.batch_size):
                                valid_input_batch, valid_label_batch = next(valid_batch_data)
                                valid_feed_dict = {
                                    model.x: valid_input_batch,
                                    model.y: valid_label_batch,
                                }
                                # Run one validation step
                                valid_start_time = time.time()
                                valid_batch_loss, valid_batch_acc = sess.run(
                                    [model.mean_loss, model.acc], feed_dict=valid_feed_dict)
                                valid_time_window.append(time.time() - valid_start_time)
                                valid_loss_window.append(valid_batch_loss)
                                valid_acc_window.append(valid_batch_acc)
                                message = 'Valid-Step %-7d [%.03f sec/step, valid_loss=%.05f, avg_loss=%.05f, valid_acc=%.05f, avg_acc=%.05f]' % (
                                    valid_step, valid_time_window.average, valid_batch_loss,
                                    valid_loss_window.average, valid_batch_acc,
                                    valid_acc_window.average)
                                log(message)
                                summary_writer_dev.add_summary(
                                    sess.run(valid_stats, feed_dict=valid_feed_dict), valid_step)
                                valid_step += 1
                        log('Done validation! Total time cost (sec): ' + str(time.time() - valid_start))

        except Exception as e:
            log('Exiting due to exception: %s' % e)
            traceback.print_exc()
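# For reference, a hedged sketch of loading the SavedModel exported above and
# running its 'predict_Pinyin2Hanzi' signature from a fresh session. The export
# path and the `pinyin_id_batch` input are placeholders, not values from the
# original script:
with tf.Session(graph=tf.Graph()) as sess:
    meta = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
                                      'serving_dir0001')  # hypothetical export dir
    sig = meta.signature_def['predict_Pinyin2Hanzi']
    x = sess.graph.get_tensor_by_name(sig.inputs['pinyin'].name)
    preds = sess.graph.get_tensor_by_name(sig.outputs['hanzi'].name)
    # hanzi_ids = sess.run(preds, feed_dict={x: pinyin_id_batch})  # hypothetical batch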
def generate_feed_dict(config, tower_data):
    # Split the batch evenly across the GPU towers and inspect the shard shapes
    train_mod = math.ceil(len(tower_data['input_ids']) / config['model']['gpu_num'])
    for k in range(config['model']['gpu_num']):
        start = k * train_mod
        end = start + train_mod
        for key in tower_data:
            print(np.array(tower_data[key][start:end]).shape)


config = {
    'data': {
        'batch_size': 10,
        'train_dataset_filePath': '/datastore/liu121/bert_trail/train_data/train_data_%d.pkl',
        'train_file_num': 5
    },
    'model': {'gpu_num': 2}
}

for i in range(config['data']['train_file_num']):
    df = DataFeeder(config, i)
    dataset = df.dataset_generator()
    print('total dataset size: ', np.array(df.train_input_ids).shape[0])
    exit()  # debug: stop after checking the first file's size
    for input_ids, input_mask, segment_ids, masked_lm_positions, masked_lm_ids, \
            masked_lm_weights, next_sentence_labels in dataset:
        print('batch size: ', np.array(input_ids).shape)
        tower_data = {'input_ids': input_ids,
                      'input_mask': input_mask,
                      'segment_ids': segment_ids,
                      'masked_lm_positions': masked_lm_positions,
                      'masked_lm_ids': masked_lm_ids,
                      'masked_lm_weights': masked_lm_weights,
                      'next_sentence_labels': next_sentence_labels}
        # generate_feed_dict(config, tower_data)
        print('=====================================')
        exit()
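# `generate_feed_dict` above only prints the tower shard shapes. A hedged sketch
# of how it might actually build a multi-GPU feed dict, assuming hypothetical
# per-tower placeholders stored as `tower_placeholders[k][key]` (not present in
# the original snippet):
import math


def generate_feed_dict_sketch(config, tower_data, tower_placeholders):
    feed_dict = {}
    shard = math.ceil(len(tower_data['input_ids']) / config['model']['gpu_num'])
    for k in range(config['model']['gpu_num']):
        start, end = k * shard, (k + 1) * shard
        for key in tower_data:
            # Map each tower's placeholder to its slice of the batch
            feed_dict[tower_placeholders[k][key]] = tower_data[key][start:end]
    return feed_dict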
infolog.set_tf_log(load_path)  # save the Estimator logs
tf.logging.set_verbosity(tf.logging.INFO)  # required for the training log to be printed

# Load data
inputs, targets, word_to_index, index_to_word, VOCAB_SIZE, INPUT_LENGTH, OUTPUT_LENGTH = load_data(hp)  # (50000, 29), (50000, 12)
hp.add_hparam('VOCAB_SIZE', VOCAB_SIZE)
hp.add_hparam('INPUT_LENGTH', INPUT_LENGTH)  # 29
hp.add_hparam('OUTPUT_LENGTH', OUTPUT_LENGTH)  # 11

train_input, test_input, train_target, test_target = train_test_split(
    inputs, targets, test_size=0.1, random_state=13371447)
datfeeder = DataFeeder(train_input, train_target, test_input, test_target,
                       batch_size=hp.BATCH_SIZE, num_epoch=hp.NUM_EPOCHS)


def seq_accuracy(labels, predictions):
    # Can receive the values handed over by tf.estimator.EstimatorSpec, plus
    # features and labels, as arguments.
    return {
        'seq_accuracy': tf.metrics.mean(
            tf.reduce_prod(
                tf.cast(tf.equal(predictions['predition'], labels[:, 1:]), tf.float16),
                axis=-1))
    }
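# A sketch of how `seq_accuracy` could be attached to an Estimator; this assumes
# a TF version (>= 1.13) where tf.estimator.add_metrics is available, and a
# hypothetical model_fn `my_model_fn` whose predictions dict uses the same
# 'predition' key as the metric above:
estimator = tf.estimator.Estimator(model_fn=my_model_fn, model_dir=load_path)
estimator = tf.estimator.add_metrics(estimator, seq_accuracy)
# estimator.evaluate(...) now reports 'seq_accuracy' alongside the model's metrics.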
config = tf.ConfigProto()  # occupy GPU memory gracefully
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    # Load the graph
    new_saver = tf.train.import_meta_graph('log3/model.ckpt-1010600.meta')
    # Load the values
    new_saver.restore(sess, tf.train.latest_checkpoint('log3/'))
    # Look up the nodes we need by name
    graph = tf.get_default_graph()
    content_tensor = graph.get_tensor_by_name('IteratorGetNext:2')
    dan_tensor = graph.get_tensor_by_name('fc/fc_dan/BiasAdd:0')
    cdssm_tensor = graph.get_tensor_by_name('fc/fc_cdssm/BiasAdd:0')
    feeder = DataFeeder('train_hrs_video.txt')
    with open('pred_v3.txt', 'w') as f:
        for i, (key, dan, cdssm, content) in enumerate(feeder.read_feature()):
            dan = np.array(dan).astype(int)
            cdssm = np.array(cdssm).astype(int)
            content = np.expand_dims(content, 0)
            dan_v, cdssm_v = sess.run([dan_tensor, cdssm_tensor],
                                      feed_dict={content_tensor: content})
            # Clip predictions to the byte range [0, 255] and round to integers
            dan_v = np.round(np.minimum(255, np.maximum(0, dan_v[0]))).astype(int)
            cdssm_v = np.round(np.minimum(255, np.maximum(0, cdssm_v[0]))).astype(int)
            f.write(key + '\t' + '\t'.join(str(x) for x in dan_v) + '\t' +
                    '\t'.join(str(x) for x in cdssm_v) + '\n')
            if i % 10000 == 0:
                print('processed %d examples' % i)  # progress log (body assumed; truncated in the original)
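# The tensor names above ('IteratorGetNext:2', 'fc/fc_dan/BiasAdd:0', ...) are
# specific to this checkpoint's graph. One way to rediscover such names after
# importing a meta graph (a sketch; the filter strings are just examples):
for op in tf.get_default_graph().get_operations():
    if 'BiasAdd' in op.name or op.name.startswith('IteratorGetNext'):
        print(op.name, [t.shape for t in op.outputs])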
def train(logdir, args):
    # Parse checkpoint path, arguments, and hparams
    checkpoint_path = os.path.join(logdir, 'model.ckpt')
    input_path = args.data_dir
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)
    log('Using model: %s' % args.model)

    # Set up DataFeeder:
    with tf.variable_scope('datafeeder') as scope:
        hp.aishell = True
        hp.prime = False
        hp.stcmd = False
        hp.data_path = 'D:/pycharm_proj/corpus_zn/'
        hp.initial_learning_rate = 0.001
        hp.decay_learning_rate = False
        hp.data_length = 512
        hp.batch_size = 64
        feeder = DataFeeder(args=hp)
        log('num_wavs: ' + str(len(feeder.wav_lst)))  # 283600
        hp.input_vocab_size = len(feeder.pny_vocab)
        hp.final_output_dim = len(feeder.pny_vocab)
        hp.steps_per_epoch = len(feeder.wav_lst) // hp.batch_size
        log('steps_per_epoch: ' + str(hp.steps_per_epoch))  # 17725
        log('pinyin_vocab_size: ' + str(hp.input_vocab_size))  # 1292
        hp.label_vocab_size = len(feeder.han_vocab)
        log('label_vocab_size: ' + str(hp.label_vocab_size))  # 6291

    # Set up model:
    global_step = tf.Variable(initial_value=0, name='global_step', trainable=False)
    valid_step = 0
    # valid_global_step = tf.Variable(initial_value=0, name='valid_global_step', trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model(args.model, hp)
        model.build_graph()
        model.add_loss()
        model.add_decoder()
        model.add_optimizer(global_step=global_step)

    # Summaries:
    stats = add_stats(model=model)
    valid_stats = add_dev_stats(model)

    # Set up saver and bookkeeping:
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    wer_window = ValueWindow(100)
    valid_time_window = ValueWindow(100)
    valid_loss_window = ValueWindow(100)
    valid_wer_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=20)
    first_serving = True

    # Train!
    with tf.Session() as sess:
        try:
            # Set up writer and initializer
            summary_writer = tf.summary.FileWriter(logdir, sess.graph)
            sess.run(tf.global_variables_initializer())

            # Restore from a checkpoint if the user requested it.
            if args.restore_step:
                restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s' % restore_path)
            else:
                log('Starting new training run')

            step = 0
            for i in range(args.epochs):
                batch_data = feeder.get_am_batch()
                log('Training epoch ' + str(i) + ':')
                for j in range(hp.steps_per_epoch):
                    input_batch = next(batch_data)
                    feed_dict = {model.inputs: input_batch['the_inputs'],
                                 model.labels: input_batch['the_labels'],
                                 model.input_lengths: input_batch['input_length'],
                                 model.label_lengths: input_batch['label_length']}

                    # Run one step
                    start_time = time.time()
                    total_step, array_loss, batch_loss, opt = sess.run(
                        [global_step, model.ctc_loss, model.batch_loss, model.optimize],
                        feed_dict=feed_dict)
                    time_window.append(time.time() - start_time)
                    step = total_step

                    # Track loss
                    # loss = np.sum(array_loss).item() / hp.batch_size
                    loss_window.append(batch_loss)
                    message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, lr=%.07f]' % (
                        step, time_window.average, batch_loss, loss_window.average,
                        K.get_value(model.learning_rate))
                    log(message)
                    # ctc_loss returns an array of shape [batch_size, 1]; without summing
                    # it first, this raised "only size-1 arrays can be converted to
                    # Python scalars".

                    # Check for diverged loss
                    if math.isnan(batch_loss):
                        log('Loss exploded to %.05f at step %d!' % (batch_loss, step))
                        raise Exception('Loss exploded')

                    # Write summaries
                    if step % args.summary_interval == 0:
                        log('Writing summary at step: %d' % step)
                        summary_writer.add_summary(sess.run(stats, feed_dict=feed_dict), step)

                    # Save checkpoints
                    if step % args.checkpoint_interval == 0:
                        log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
                        saver.save(sess, checkpoint_path, global_step=step)
                        log('test acc...')
                        # eval_start_time = time.time()
                        # with tf.name_scope('eval') as scope:
                        #     with open(os.path.expanduser('~/my_asr2/datasets/resource/preprocessedData/dev-meta.txt'), encoding='utf-8') as f:
                        #         metadata = [line.strip().split('|') for line in f]
                        #     random.shuffle(metadata)
                        #     eval_loss = []
                        #     batch_size = args.hp.batch_size
                        #     batchs = len(metadata) // batch_size
                        #     for i in range(batchs):
                        #         batch = metadata[i * batch_size: i * batch_size + batch_size]
                        #         batch = list(map(eval_get_example, batch))
                        #         batch = eval_prepare_batch(batch)
                        #         feed_dict = {'labels': batch[0]}
                        #         label, final_pred_label, log_probabilities = sess.run(
                        #             [model.labels[0], model.decoded[0], model.log_probabilities[0]])
                        #         print('label:', label)
                        #         print('final_pred_label:', final_pred_label[0])
                        #     log('eval time: %.03f, avg_eval_loss: %.05f' % (time.time() - eval_start_time, np.mean(eval_loss)))
                        label, final_pred_label, log_probabilities, y_pred2 = sess.run(
                            [model.labels, model.decoded, model.log_probabilities, model.y_pred2],
                            feed_dict=feed_dict)
                        # At first this raised an error because the fetches were not
                        # wrapped in []; see https://github.com/tensorflow/tensorflow/issues/11840
                        print('label.shape:', label.shape)  # (batch_size, label_length)
                        print('final_pred_label.shape:', np.asarray(final_pred_label).shape)  # (1, batch_size, decode_length <= label_length)
                        print('y_pred2.shape:', y_pred2.shape)
                        print('label:', label[0])
                        print('y_pred2:', y_pred2[0])
                        print('final_pred_label:', np.asarray(final_pred_label)[0][0])
                        # Could not be printed at first: tf.nn.ctc_beam_search_decoder
                        # returns a SparseTensor. The Keras-wrapped decoder converts the
                        # sparse output to a dense tensor, which can be printed.
                        # waveform = audio.inv_spectrogram(spectrogram.T)
                        # audio.save_wav(waveform, os.path.join(logdir, 'step-%d-audio.wav' % step))
                        # plot.plot_alignment(alignment, os.path.join(logdir, 'step-%d-align.png' % step),
                        #                     info='%s, %s, %s, step=%d, loss=%.5f' % (
                        #                         args.model, commit, time_string(), step, loss))
                        # log('Input: %s' % sequence_to_text(input_seq))

                    # Export a serving model at the end of each epoch
                    if step % hp.steps_per_epoch == 0:
                        # Set up serving builder and signature map
                        serve_dir = args.serving_dir + '_' + str(total_step // hp.steps_per_epoch - 1)
                        if os.path.exists(serve_dir):
                            os.removedirs(serve_dir)
                        builder = tf.saved_model.builder.SavedModelBuilder(export_dir=serve_dir)
                        input_spec = tf.saved_model.utils.build_tensor_info(model.inputs)
                        input_len = tf.saved_model.utils.build_tensor_info(model.input_lengths)
                        output_labels = tf.saved_model.utils.build_tensor_info(model.decoded2)
                        output_logits = tf.saved_model.utils.build_tensor_info(model.pred_logits)
                        prediction_signature = (
                            tf.saved_model.signature_def_utils.build_signature_def(
                                inputs={'spec': input_spec, 'len': input_len},
                                outputs={'label': output_labels, 'logits': output_logits},
                                method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
                        if first_serving:
                            first_serving = False
                            builder.add_meta_graph_and_variables(
                                sess=sess,
                                tags=[tf.saved_model.tag_constants.SERVING, 'ASR'],
                                signature_def_map={
                                    'predict_AudioSpec2Pinyin': prediction_signature,
                                },
                                main_op=tf.tables_initializer(),
                                strip_default_attrs=True)
                        else:
                            builder.add_meta_graph_and_variables(
                                sess=sess,
                                tags=[tf.saved_model.tag_constants.SERVING, 'ASR'],
                                signature_def_map={
                                    'predict_AudioSpec2Pinyin': prediction_signature,
                                },
                                strip_default_attrs=True)
                        builder.save()
                        log('Done storing serving model')

                    # Validation
                    if step % hp.steps_per_epoch == 0 and i >= 10:
                        log('validation...')
                        valid_start = time.time()
                        # Deep-copy so the dev settings don't mutate the training hparams
                        valid_hp = copy.deepcopy(hp)
                        valid_hp.data_type = 'dev'
                        valid_hp.thchs30 = True
                        valid_hp.aishell = False
                        valid_hp.prime = False
                        valid_hp.stcmd = False
                        valid_hp.shuffle = True
                        valid_hp.data_length = None
                        valid_feeder = DataFeeder(args=valid_hp)
                        valid_batch_data = valid_feeder.get_am_batch()
                        log('valid_num_wavs: ' + str(len(valid_feeder.wav_lst)))  # 15219
                        valid_hp.input_vocab_size = len(valid_feeder.pny_vocab)
                        valid_hp.final_output_dim = len(valid_feeder.pny_vocab)
                        valid_hp.steps_per_epoch = len(valid_feeder.wav_lst) // valid_hp.batch_size
                        log('valid_steps_per_epoch: ' + str(valid_hp.steps_per_epoch))  # 951
                        log('valid_pinyin_vocab_size: ' + str(valid_hp.input_vocab_size))  # 1124
                        valid_hp.label_vocab_size = len(valid_feeder.han_vocab)
                        log('valid_label_vocab_size: ' + str(valid_hp.label_vocab_size))  # 3327
                        words_num = 0
                        word_error_num = 0

                        # One pass over the dev set is enough
                        with tf.variable_scope('validation') as scope:
                            for k in range(len(valid_feeder.wav_lst) // valid_hp.batch_size):
                                valid_input_batch = next(valid_batch_data)
                                valid_feed_dict = {model.inputs: valid_input_batch['the_inputs'],
                                                   model.labels: valid_input_batch['the_labels'],
                                                   model.input_lengths: valid_input_batch['input_length'],
                                                   model.label_lengths: valid_input_batch['label_length']}
                                # Run one validation step
                                valid_start_time = time.time()
                                valid_batch_loss, valid_WER = sess.run(
                                    [model.batch_loss, model.WER], feed_dict=valid_feed_dict)
                                valid_time_window.append(time.time() - valid_start_time)
                                valid_loss_window.append(valid_batch_loss)
                                valid_wer_window.append(valid_WER)
                                message = 'Valid-Step %-7d [%.03f sec/step, valid_loss=%.05f, avg_loss=%.05f, WER=%.05f, avg_WER=%.05f, lr=%.07f]' % (
                                    valid_step, valid_time_window.average, valid_batch_loss,
                                    valid_loss_window.average, valid_WER, valid_wer_window.average,
                                    K.get_value(model.learning_rate))
                                log(message)
                                summary_writer.add_summary(
                                    sess.run(valid_stats, feed_dict=valid_feed_dict), valid_step)
                                valid_step += 1
                        log('Done validation! Total time cost (sec): ' + str(time.time() - valid_start))

        except Exception as e:
            log('Exiting due to exception: %s' % e)
            traceback.print_exc()
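# `model.WER` above is computed inside the graph. For reference, a plain-Python
# sketch of word error rate via Levenshtein distance over token id sequences;
# this is an illustration, not the model's in-graph implementation:
def wer_sketch(ref, hyp):
    # ref, hyp: lists of token ids; returns edit_distance / len(ref)
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i  # deleting everything from ref
    for j in range(len(hyp) + 1):
        d[0][j] = j  # inserting everything from hyp
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            sub = d[i - 1][j - 1] + (ref[i - 1] != hyp[j - 1])
            d[i][j] = min(sub, d[i - 1][j] + 1, d[i][j - 1] + 1)
    return d[len(ref)][len(hyp)] / max(1, len(ref))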
def train(args):
    if args.model_path is None:
        msg = 'Prepare for new run ...'
        output_dir = os.path.join(
            args.log_dir,
            args.run_name + '_' + datetime.datetime.now().strftime('%m%d_%H%M'))
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        ckpt_dir = os.path.join(
            args.ckpt_dir,
            args.run_name + '_' + datetime.datetime.now().strftime('%m%d_%H%M'))
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)
    else:
        msg = 'Restart previous run ...\nlogs to save to %s, ckpt to save to %s, model to load from %s' % \
              (args.log_dir, args.ckpt_dir, args.model_path)
        output_dir = args.log_dir
        ckpt_dir = args.ckpt_dir
        if not os.path.isdir(output_dir):
            print('Invalid log dir: %s' % output_dir)
            return
        if not os.path.isdir(ckpt_dir):
            print('Invalid ckpt dir: %s' % ckpt_dir)
            return

    set_logger(os.path.join(output_dir, 'outputs.log'))
    logging.info(msg)

    global device
    if args.device is not None:
        logging.info('Setting device to ' + args.device)
        device = torch.device(args.device)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    logging.info('Setting up...')
    hparams.parse(args.hparams)
    logging.info(hparams_debug_string())

    model = EdgeClassification
    if hparams.use_roberta:
        logging.info('Using Roberta...')
        model = RobertaEdgeClassification

    global_step = 0
    if args.model_path is None:
        if hparams.load_pretrained:
            logging.info('Load online pretrained model...' + (
                ('cached at ' + args.cache_path) if args.cache_path is not None else ''))
            if hparams.use_roberta:
                model = model.from_pretrained('roberta-base',
                                              cache_dir=args.cache_path,
                                              hparams=hparams)
            else:
                model = model.from_pretrained('bert-base-uncased',
                                              cache_dir=args.cache_path,
                                              hparams=hparams)
        else:
            logging.info('Build model from scratch...')
            # Match the config family to the chosen model class
            if hparams.use_roberta:
                config = RobertaConfig.from_pretrained('roberta-base')
            else:
                config = BertConfig.from_pretrained('bert-base-uncased')
            model = model(config=config, hparams=hparams)
    else:
        if not os.path.isdir(args.model_path):
            raise OSError(str(args.model_path) + ' not found')
        logging.info('Load saved model from %s ...' % args.model_path)
        model = model.from_pretrained(args.model_path, hparams=hparams)
        # Recover the step count from the checkpoint directory name
        step = args.model_path.split('_')[-1]
        if step.isnumeric():
            global_step = int(step)
            logging.info('Initial step=%d' % global_step)

    if hparams.use_roberta:
        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    else:
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    hparams.parse(args.hparams)
    logging.info(hparams_debug_string())

    if hparams.text_sample_eval:
        if args.eval_text_path is None:
            raise ValueError('eval_text_path not given')
        if ':' not in args.eval_text_path:
            eval_data_paths = [args.eval_text_path]
        else:
            eval_data_paths = args.eval_text_path.split(':')
        eval_feeder = []
        for p in eval_data_paths:
            name = os.path.split(p)[-1]
            if name.endswith('.tsv'):
                name = name[:-4]
            eval_feeder.append((name, ExternalTextFeeder(p, hparams, tokenizer, 'dev')))
    else:
        eval_feeder = [('', DataFeeder(args.data_dir, hparams, tokenizer, 'dev'))]

    tb_writer = SummaryWriter(output_dir)

    # Exclude biases and LayerNorm weights from weight decay
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [p for n, p in model.named_parameters()
                   if not any(nd in n for nd in no_decay)],
        'weight_decay': hparams.weight_decay
    }, {
        'params': [p for n, p in model.named_parameters()
                   if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=hparams.learning_rate,
                      eps=hparams.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=hparams.warmup_steps,
        lr_decay_step=hparams.lr_decay_step,
        max_lr_decay_rate=hparams.max_lr_decay_rate)

    acc_step = global_step * hparams.gradient_accumulation_steps
    time_window = ValueWindow()
    loss_window = ValueWindow()
    acc_window = ValueWindow()
    model.to(device)
    model.zero_grad()
    tr_loss = tr_acc = 0.0
    start_time = time.time()

    if args.model_path is not None:
        logging.info('Load saved optimizer/scheduler state from %s ...' % args.model_path)
        if os.path.exists(os.path.join(args.model_path, 'optimizer.pt')) \
                and os.path.exists(os.path.join(args.model_path, 'scheduler.pt')):
            optimizer.load_state_dict(
                torch.load(os.path.join(args.model_path, 'optimizer.pt')))
            scheduler.load_state_dict(
                torch.load(os.path.join(args.model_path, 'scheduler.pt')))
        else:
            logging.warning('Could not find saved optimizer/scheduler')

    if global_step > 0:
        logs = run_eval(args, model, eval_feeder)
        for key, value in logs.items():
            tb_writer.add_scalar(key, value, global_step)

    logging.info('Start training...')
    if hparams.text_sample_train:
        train_feeder = PrebuiltTrainFeeder(args.train_text_path, hparams, tokenizer, 'train')
    else:
        train_feeder = DataFeeder(args.data_dir, hparams, tokenizer, 'train')

    while True:
        batch = train_feeder.next_batch()
        model.train()
        outputs = model(input_ids=batch.input_ids.to(device),
                        attention_mask=batch.input_mask.to(device),
                        token_type_ids=None if batch.token_type_ids is None
                        else batch.token_type_ids.to(device),
                        labels=batch.labels.to(device))
        loss = outputs['loss']
        preds = outputs['preds']
        acc = torch.mean((preds.cpu() == batch.labels).float())
        preds = preds.cpu().detach().numpy()
        labels = batch.labels.detach().numpy()
        # Per-class accuracy on positive and negative edges
        t_acc = np.sum(np.logical_and(preds == 1, labels == 1)) / np.sum(labels == 1)
        f_acc = np.sum(np.logical_and(preds == 0, labels == 0)) / np.sum(labels == 0)

        if hparams.gradient_accumulation_steps > 1:
            loss = loss / hparams.gradient_accumulation_steps
            acc = acc / hparams.gradient_accumulation_steps
        tr_loss += loss.item()
        tr_acc += acc.item()
        loss.backward()
        acc_step += 1
        if acc_step % hparams.gradient_accumulation_steps != 0:
            continue

        torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.max_grad_norm)
        optimizer.step()
        scheduler.step()
        model.zero_grad()
        global_step += 1
        step_time = time.time() - start_time
        time_window.append(step_time)
        loss_window.append(tr_loss)
        acc_window.append(tr_acc)

        if global_step % args.save_steps == 0:
            # Save model checkpoint
            model_to_save = model.module if hasattr(model, 'module') else model
            cur_ckpt_dir = os.path.join(ckpt_dir, 'checkpoint_%d' % global_step)
            if not os.path.exists(cur_ckpt_dir):
                os.makedirs(cur_ckpt_dir)
            model_to_save.save_pretrained(cur_ckpt_dir)
            torch.save(args, os.path.join(cur_ckpt_dir, 'training_args.bin'))
            torch.save(optimizer.state_dict(), os.path.join(cur_ckpt_dir, 'optimizer.pt'))
            torch.save(scheduler.state_dict(), os.path.join(cur_ckpt_dir, 'scheduler.pt'))
            logging.info('Saving model checkpoint to %s', cur_ckpt_dir)

        if global_step % args.logging_steps == 0:
            logs = run_eval(args, model, eval_feeder)
            logs['learning_rate'] = scheduler.get_lr()[0]
            logs['loss'] = loss_window.average
            logs['acc'] = acc_window.average
            for key, value in logs.items():
                tb_writer.add_scalar(key, value, global_step)
            message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, acc=%.05f, avg_acc=%.05f, t_acc=%.05f, f_acc=%.05f]' % (
                global_step, step_time, tr_loss, loss_window.average,
                tr_acc, acc_window.average, t_acc, f_acc)
            logging.info(message)

        # Reset the per-step accumulators after each effective optimizer step
        tr_loss = tr_acc = 0.0
        start_time = time.time()
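# `run_eval` is referenced above but not defined in this snippet. A hedged
# sketch consistent with how it is called (returns a dict of scalars for
# TensorBoard); the `eval_batches()` iterator is a hypothetical feeder method,
# and the snippet's `device`, `np`, and `torch` context is assumed:
def run_eval_sketch(args, model, eval_feeder):
    model.eval()
    logs = {}
    with torch.no_grad():
        for name, feeder in eval_feeder:
            losses, accs = [], []
            for batch in feeder.eval_batches():  # hypothetical iterator
                outputs = model(input_ids=batch.input_ids.to(device),
                                attention_mask=batch.input_mask.to(device),
                                labels=batch.labels.to(device))
                losses.append(outputs['loss'].item())
                accs.append(torch.mean(
                    (outputs['preds'].cpu() == batch.labels).float()).item())
            prefix = ('eval_' + name + '_') if name else 'eval_'
            logs[prefix + 'loss'] = float(np.mean(losses))
            logs[prefix + 'acc'] = float(np.mean(accs))
    return logs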