def train_simple(sess):
    """ Training Mode: Create a new model and train the network """
    tf.set_random_seed(seed)

    ### Prepare data for training
    vocab_dict, word_embedding_array = DataProcessor().prepare_vocab_embeddingdict()
    # vocab_dict contains _PAD and _UNK, but word_embedding_array does not
    train_data = DataProcessor().prepare_news_data(vocab_dict, data_type="training")
    # Validation data is loaded as one whole batch with padded matrices
    val_batch = batch_load_data(DataProcessor().prepare_news_data(vocab_dict, data_type="validation"))

    # If filter outputs are concatenated, the sentence embedding size must be divisible
    # by the number of filter lengths; shrink it to the nearest multiple if necessary.
    fil_lens_to_test = FLAGS.max_filter_length - FLAGS.min_filter_length + 1
    if FLAGS.handle_filter_output == "concat" and FLAGS.sentembed_size % fil_lens_to_test != 0:
        q = int(FLAGS.sentembed_size // fil_lens_to_test)
        FLAGS.sentembed_size = q * fil_lens_to_test
        print("corrected embedding size: %d" % FLAGS.sentembed_size)

    # Create Model with various operations
    model = MY_Model(sess, len(vocab_dict) - 2)

    init_epoch = 1
    # Resume training if indicated: select the model to load
    if FLAGS.model_to_load != -1:
        selected_modelpath = FLAGS.train_dir + "/step-a.model.ckpt.epoch-" + str(FLAGS.model_to_load)
        init_epoch = FLAGS.model_to_load + 1
        print("Reading model parameters from %s" % selected_modelpath)
        model.saver.restore(sess, selected_modelpath)
        print("Model loaded.")

    # Initialize word embedding before training
    sess.run(model.vocab_embed_variable.assign(word_embedding_array))

    ### STEP A: Start pretraining the policy with supervised labels (simple cross-entropy training)
    counter = 0
    max_val_acc = -1
    for epoch in range(init_epoch, FLAGS.train_epoch_crossentropy + 1):
        ep_time = time.time()  # to check duration
        train_data.shuffle_fileindices()

        # Start Batch Training
        step = 1
        total_loss = 0
        while (step * FLAGS.batch_size) <= len(train_data.fileindices):
            # Get batch data as Numpy Arrays
            batch = train_data.get_batch((step - 1) * FLAGS.batch_size, step * FLAGS.batch_size)
            # Run optimizer: optimize policy and reward estimator
            _, ce_loss = sess.run(
                [model.train_op_policynet_withgold, model.cross_entropy_loss],
                feed_dict={model.document_placeholder: batch.docs,
                           model.label_placeholder: batch.labels,
                           model.weight_placeholder: batch.weights})
            total_loss += ce_loss
            step += 1
        # END-WHILE-TRAINING

        ## Evaluation metrics on the validation set
        FLAGS.authorise_gold_label = False
        prev_use_dpt = FLAGS.use_dropout
        total_loss /= step
        FLAGS.use_dropout = False
        # Retrieve the validation batch with updated logits in it
        val_batch = batch_predict_with_a_model(val_batch, "validation", model, session=sess)
        FLAGS.authorise_gold_label = True
        FLAGS.use_dropout = prev_use_dpt

        probs = sess.run(model.predictions, feed_dict={model.logits_placeholder: val_batch.logits})
        validation_acc = accuracy_qas_top(probs, val_batch.labels, val_batch.weights, val_batch.isf_score_ids)
        val_mrr = mrr_metric(probs, val_batch.labels, val_batch.weights, val_batch.isf_score_ids, "validation")
        val_map = map_score(probs, val_batch.labels, val_batch.weights, val_batch.isf_score_ids, "validation")

        print("\tEpoch %2d || Train ce_loss: %4.3f || Val acc: %.4f || Val mrr: %.4f || Val map: %.4f || duration: %3.2f" %
              (epoch, total_loss, validation_acc, val_mrr, val_map, time.time() - ep_time))

        if FLAGS.save_models:
            print("Saving model after epoch completion")
            checkpoint_path = os.path.join(FLAGS.train_dir, "step-a.model.ckpt.epoch-" + str(epoch))
            model.saver.save(sess, checkpoint_path)
        print("------------------------------------------------------------------------------------------")
    # END-FOR-EPOCH

    print("Optimization Finished!")
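# --- Illustrative driver (assumption, not part of the original pipeline) --------------
# The sess-based entry points such as train_simple above expect an already-open session.
# This sketch only mirrors the graph/session setup used by the self-contained train()
# variants elsewhere in this file; the helper name and the use of FLAGS.use_gpu here are
# assumptions for illustration.
def _run_train_simple_example():
    with tf.Graph().as_default(), tf.device(FLAGS.use_gpu):
        # Soft placement falls back to CPU when an op has no GPU kernel
        config = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(config=config) as sess:
            train_simple(sess)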
def train():
    """ Training Mode: Create a new model and train the network """
    # Training: use the tf default graph
    with tf.Graph().as_default(), tf.device(FLAGS.use_gpu):
        config = tf.ConfigProto(allow_soft_placement=True)

        # Start a session
        with tf.Session(config=config) as sess:
            ### Prepare data for training
            print("Prepare vocab dict and read pretrained word embeddings ...")
            vocab_dict, word_embedding_array = DataProcessor().prepare_vocab_embeddingdict()
            # vocab_dict contains _PAD and _UNK, but word_embedding_array does not
            print("Prepare training data ...")
            train_data = DataProcessor().prepare_news_data(vocab_dict, data_type="training")
            print("Prepare validation data ...")
            validation_data = DataProcessor().prepare_news_data(vocab_dict, data_type="validation")
            print("Prepare ROUGE reward generator ...")
            rouge_generator = Reward_Generator()

            # Create Model with various operations
            model = MY_Model(sess, len(vocab_dict) - 2)

            # Start training with some pretrained model
            start_epoch = 1
            # selected_modelpath = FLAGS.train_dir + "/model.ckpt.epoch-" + str(start_epoch - 1)
            # if not os.path.isfile(selected_modelpath):
            #     print("Model not found in checkpoint folder.")
            #     exit(0)
            # # Reload saved model and test
            # print("Reading model parameters from %s" % selected_modelpath)
            # model.saver.restore(sess, selected_modelpath)
            # print("Model loaded.")

            # Initialize word embedding before training
            print("Initialize word embedding vocabulary with pretrained embeddings ...")
            sess.run(model.vocab_embed_variable.assign(word_embedding_array))

            ########### Start (No Mixer) Training: Reinforcement learning #################
            # Reward-aware training as part of reward-weighted CE.
            # No curriculum learning: no annealing, consider the full document as in MRT.
            # Multiple samples (including the gold sample), no future reward, similar to MRT.
            # PYROUGE is not used during training, to avoid rewriting files repeatedly.
            # Approximate MRT with multiple pre-estimated oracle samples.
            # June 2017: use a single sample from multiple oracles.
            ###############################################################################
            print("Start Reinforcement Training (single rollout from largest prob mass) ...")

            for epoch in range(start_epoch, FLAGS.train_epoch_wce + 1):
                print("MRT: Epoch " + str(epoch))
                print("MRT: Epoch " + str(epoch) + " : Reshuffle training document indices")
                train_data.shuffle_fileindices()
                print("MRT: Epoch " + str(epoch) + " : Restore Rouge Dict")
                rouge_generator.restore_rouge_dict()

                # Start Batch Training
                step = 1
                while (step * FLAGS.batch_size) <= len(train_data.fileindices):
                    # Get batch data as Numpy Arrays
                    (batch_docnames, batch_docs, batch_label, batch_weight,
                     batch_oracle_multiple, batch_reward_multiple) = train_data.get_batch(
                        (step - 1) * FLAGS.batch_size, step * FLAGS.batch_size)
                    # print(batch_docnames)
                    # print(batch_label[0])
                    # print(batch_weight[0])
                    # print(batch_oracle_multiple[0])
                    # print(batch_reward_multiple[0])
                    # exit(0)

                    # Print the progress
                    if (step % FLAGS.training_checkpoint) == 0:
                        ce_loss_val, ce_loss_sum, acc_val, acc_sum = sess.run(
                            [model.rewardweighted_cross_entropy_loss_multisample,
                             model.rewardweighted_ce_multisample_loss_summary,
                             model.accuracy,
                             model.taccuracy_summary],
                            feed_dict={model.document_placeholder: batch_docs,
                                       model.predicted_multisample_label_placeholder: batch_oracle_multiple,
                                       model.actual_reward_multisample_placeholder: batch_reward_multiple,
                                       model.label_placeholder: batch_label,
                                       model.weight_placeholder: batch_weight})
                        # Print Summary to Tensor Board
                        model.summary_writer.add_summary(
                            ce_loss_sum, (epoch - 1) * len(train_data.fileindices) + step * FLAGS.batch_size)
                        model.summary_writer.add_summary(
                            acc_sum, (epoch - 1) * len(train_data.fileindices) + step * FLAGS.batch_size)
                        print("MRT: Epoch " + str(epoch) + " : Covered " + str(step * FLAGS.batch_size) + "/" +
                              str(len(train_data.fileindices)) +
                              " : Minibatch Reward Weighted Multisample CE Loss= {:.6f}".format(ce_loss_val) +
                              " : Minibatch training accuracy= {:.6f}".format(acc_val))

                    # Run optimizer: optimize policy network
                    sess.run([model.train_op_policynet_expreward],
                             feed_dict={model.document_placeholder: batch_docs,
                                        model.predicted_multisample_label_placeholder: batch_oracle_multiple,
                                        model.actual_reward_multisample_placeholder: batch_reward_multiple,
                                        model.weight_placeholder: batch_weight})

                    # Increase step
                    step += 1
                    # if step == 20:
                    #     break

                # Save Model
                print("MRT: Epoch " + str(epoch) + " : Saving model after epoch completion")
                checkpoint_path = os.path.join(FLAGS.train_dir, "model.ckpt.epoch-" + str(epoch))
                model.saver.save(sess, checkpoint_path)

                # Backup Rouge Dict
                print("MRT: Epoch " + str(epoch) + " : Saving rouge dictionary")
                rouge_generator.save_rouge_dict()

                # Performance on the validation set
                print("MRT: Epoch " + str(epoch) + " : Performance on the validation data")
                # Get Predictions: Prohibit the use of gold labels
                validation_logits, validation_labels, validation_weights = batch_predict_with_a_model(
                    validation_data, model, session=sess)
                # Validation Accuracy and Prediction
                validation_acc, validation_sum = sess.run(
                    [model.final_accuracy, model.vaccuracy_summary],
                    feed_dict={model.logits_placeholder: validation_logits.eval(session=sess),
                               model.label_placeholder: validation_labels.eval(session=sess),
                               model.weight_placeholder: validation_weights.eval(session=sess)})
                # Print Validation Summary
                model.summary_writer.add_summary(validation_sum, epoch * len(train_data.fileindices))
                print("MRT: Epoch " + str(epoch) + " : Validation (" + str(len(validation_data.fileindices)) +
                      ") accuracy= {:.6f}".format(validation_acc))

                # Writing validation predictions and final summaries
                print("MRT: Epoch " + str(epoch) + " : Writing final validation summaries")
                validation_data.write_prediction_summaries(
                    validation_logits, "model.ckpt.epoch-" + str(epoch), session=sess)

                # Estimate Rouge Scores
                rouge_score = rouge_generator.get_full_rouge(
                    FLAGS.train_dir + "/model.ckpt.epoch-" + str(epoch) + ".validation-summary-topranked",
                    "validation")
                print("MRT: Epoch " + str(epoch) + " : Validation (" + str(len(validation_data.fileindices)) +
                      ") rouge= {:.6f}".format(rouge_score))
                # break

    print("Optimization Finished!")
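# --- Reward-weighted multisample CE: a minimal numpy sketch (assumption) ---------------
# The graph op model.rewardweighted_cross_entropy_loss_multisample is not defined in this
# file. The sketch below only illustrates one plausible form that is consistent with the
# feeds used above: per-sentence cross-entropy against each pre-estimated oracle sample,
# scaled by that sample's reward and masked by the sentence weights. The helper name,
# shapes, and the exact normalisation are assumptions, not the repo's implementation.
def _rewardweighted_ce_sketch(probs, oracle_labels, rewards, weights):
    """probs: [batch, sents]; oracle_labels: [batch, samples, sents];
    rewards: [batch, samples]; weights: [batch, sents] (1 for real sentences)."""
    import numpy as np
    eps = 1e-8
    # Bernoulli cross-entropy of each oracle sample against the policy probabilities
    ce = -(oracle_labels * np.log(probs[:, None, :] + eps) +
           (1 - oracle_labels) * np.log(1 - probs[:, None, :] + eps))  # [batch, samples, sents]
    ce = (ce * weights[:, None, :]).sum(axis=2)                        # sum over sentences
    return float((rewards * ce).mean())                                # reward-weighted mean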
def train_debug(sess):
    """ Training Mode: Create a new model and train the network """
    tf.set_random_seed(seed)

    ### Prepare data for training
    print("Prepare vocab dict and read pretrained word embeddings ...")
    vocab_dict, word_embedding_array = DataProcessor().prepare_vocab_embeddingdict()
    # vocab_dict contains _PAD and _UNK, but word_embedding_array does not
    print("Prepare training data ...")
    train_data = DataProcessor().prepare_news_data(vocab_dict, data_type="training")
    print("Training size: ", len(train_data.fileindices))
    print("Prepare validation data ...")
    # Validation data is loaded as one whole batch with padded matrices
    val_batch = batch_load_data(DataProcessor().prepare_news_data(vocab_dict, data_type="validation"))
    print("Validation size: ", val_batch.docs.shape[0])

    # If filter outputs are concatenated, the sentence embedding size must be divisible
    # by the number of filter lengths; shrink it to the nearest multiple if necessary.
    fil_lens_to_test = FLAGS.max_filter_length - FLAGS.min_filter_length + 1
    if FLAGS.handle_filter_output == "concat" and FLAGS.sentembed_size % fil_lens_to_test != 0:
        q = int(FLAGS.sentembed_size // fil_lens_to_test)
        FLAGS.sentembed_size = q * fil_lens_to_test
        print("corrected embedding size: %d" % FLAGS.sentembed_size)

    # Create Model with various operations
    model = MY_Model(sess, len(vocab_dict) - 2)

    init_epoch = 1
    # Resume training if indicated: select the model to load
    if FLAGS.model_to_load != -1:
        selected_modelpath = FLAGS.train_dir + "/step-a.model.ckpt.epoch-" + str(FLAGS.model_to_load)
        init_epoch = FLAGS.model_to_load + 1
        print("Reading model parameters from %s" % selected_modelpath)
        model.saver.restore(sess, selected_modelpath)
        print("Model loaded.")

    # Initialize word embedding before training
    print("Initialize word embedding vocabulary with pretrained embeddings ...")
    sess.run(model.vocab_embed_variable.assign(word_embedding_array))

    ### STEP A: Start pretraining the policy with supervised labels (simple cross-entropy training)
    counter = 0
    max_val_acc = -1
    for epoch in range(init_epoch, FLAGS.train_epoch_crossentropy + 1):
        ep_time = time.time()  # to check duration
        train_data.shuffle_fileindices()

        # Start Batch Training
        step = 1
        total_ce_loss = 0
        total_train_acc = 0
        while (step * FLAGS.batch_size) <= len(train_data.fileindices):
            # Get batch data as Numpy Arrays
            batch = train_data.get_batch((step - 1) * FLAGS.batch_size, step * FLAGS.batch_size)

            # Run optimizer: optimize policy and reward estimator
            sess.run([model.train_op_policynet_withgold],
                     feed_dict={model.document_placeholder: batch.docs,
                                model.label_placeholder: batch.labels,
                                model.weight_placeholder: batch.weights})

            # Re-run the forward pass without dropout to log training loss and accuracy
            prev_use_dpt = FLAGS.use_dropout
            FLAGS.use_dropout = False
            batch_logits, ce_loss, merged_summ = sess.run(
                [model.logits, model.cross_entropy_loss, model.merged],
                feed_dict={model.document_placeholder: batch.docs,
                           model.label_placeholder: batch.labels,
                           model.weight_placeholder: batch.weights})
            total_ce_loss += ce_loss
            probs = sess.run(model.predictions, feed_dict={model.logits_placeholder: batch_logits})
            training_acc = accuracy_qas_top(probs, batch.labels, batch.weights, batch.isf_score_ids)
            FLAGS.use_dropout = prev_use_dpt
            total_train_acc += training_acc

            # Print the progress
            if (step % FLAGS.training_checkpoint) == 0:
                total_train_acc /= FLAGS.training_checkpoint
                acc_sum = sess.run(model.tstepa_accuracy_summary,
                                   feed_dict={model.train_acc_placeholder: total_train_acc})
                total_ce_loss /= FLAGS.training_checkpoint

                # Print Summary to Tensor Board
                model.summary_writer.add_summary(merged_summ, counter)
                model.summary_writer.add_summary(acc_sum, counter)

                # Performance on the validation set
                # Get Predictions: Prohibit the use of gold labels
                FLAGS.authorise_gold_label = False
                prev_use_dpt = FLAGS.use_dropout
                FLAGS.use_dropout = False
                val_batch = batch_predict_with_a_model(val_batch, "validation", model, session=sess)
                FLAGS.use_dropout = prev_use_dpt
                FLAGS.authorise_gold_label = True

                # Validation Accuracy and Prediction
                probs = sess.run(model.predictions, feed_dict={model.logits_placeholder: val_batch.logits})
                validation_acc = accuracy_qas_top(probs, val_batch.labels, val_batch.weights,
                                                  val_batch.isf_score_ids)
                ce_loss_val, ce_loss_sum, acc_sum = sess.run(
                    [model.cross_entropy_loss_val, model.ce_loss_summary_val, model.vstepa_accuracy_summary],
                    feed_dict={model.logits_placeholder: val_batch.logits,
                               model.label_placeholder: val_batch.labels,
                               model.weight_placeholder: val_batch.weights,
                               model.val_acc_placeholder: validation_acc})

                # Print Validation Summary
                model.summary_writer.add_summary(acc_sum, counter)
                model.summary_writer.add_summary(ce_loss_sum, counter)

                print("Epoch %2d, step: %2d(%2d) || CE loss || Train : %4.3f , Val : %4.3f ||| ACC || Train : %.3f , Val : %.3f" %
                      (epoch, step, counter, total_ce_loss, ce_loss_val, training_acc, validation_acc))
                total_ce_loss = 0
                total_train_acc = 0

            if (step % 5) == 0:  # to have comparable tensorboard plots
                counter += 1
            # Increase step
            step += 1
        # END-WHILE-TRAINING

        ## End-of-epoch evaluation metrics on the validation set
        FLAGS.authorise_gold_label = False
        prev_use_dpt = FLAGS.use_dropout
        FLAGS.use_dropout = False
        val_batch = batch_predict_with_a_model(val_batch, "validation", model, session=sess)
        FLAGS.use_dropout = prev_use_dpt
        FLAGS.authorise_gold_label = True

        # Validation metrics
        probs = sess.run(model.predictions, feed_dict={model.logits_placeholder: val_batch.logits})
        acc = accuracy_qas_top(probs, val_batch.labels, val_batch.weights, val_batch.isf_score_ids)
        mrr = mrr_metric(probs, val_batch.labels, val_batch.weights, val_batch.isf_score_ids, "validation")
        _map = map_score(probs, val_batch.labels, val_batch.weights, val_batch.isf_score_ids, "validation")
        print("Metrics: acc: %.4f | mrr: %.4f | map: %.4f" % (acc, mrr, _map))
        print("Epoch %2d : Duration: %.4f" % (epoch, time.time() - ep_time))

        if FLAGS.save_models:
            print("Saving model after epoch completion")
            checkpoint_path = os.path.join(FLAGS.train_dir, "step-a.model.ckpt.epoch-" + str(epoch))
            model.saver.save(sess, checkpoint_path)
        print("------------------------------------------------------------------------------------------")
    # END-FOR-EPOCH

    print("Optimization Finished!")
def train(sess):
    """ Training Mode: Create a new model and train the network """
    ### Prepare data for training
    print("Prepare vocab dict and read pretrained word embeddings ...")
    vocab_dict, word_embedding_array = DataProcessor().prepare_vocab_embeddingdict()
    # vocab_dict contains _PAD and _UNK, but word_embedding_array does not
    print("Prepare training data ...")
    train_data = DataProcessor().prepare_news_data(vocab_dict, data_type="training")
    print("Prepare validation data ...")
    # Validation data is loaded as one whole batch with padded matrices
    val_batch = batch_load_data(DataProcessor().prepare_news_data(vocab_dict, data_type="validation"))

    # If filter outputs are concatenated, the sentence embedding size must be divisible
    # by the number of filter lengths; shrink it to the nearest multiple if necessary.
    fil_lens_to_test = FLAGS.max_filter_length - FLAGS.min_filter_length + 1
    if FLAGS.handle_filter_output == "concat" and FLAGS.sentembed_size % fil_lens_to_test != 0:
        q = int(FLAGS.sentembed_size // fil_lens_to_test)
        FLAGS.sentembed_size = q * fil_lens_to_test
        print("corrected embedding size: %d" % FLAGS.sentembed_size)

    # Create Model with various operations
    model = MY_Model(sess, len(vocab_dict) - 2)

    init_epoch = 1
    # Resume training if indicated: select the model to load
    if FLAGS.model_to_load != -1:
        if os.path.isfile(FLAGS.train_dir + "/step-a.model.ckpt.epoch-" + str(FLAGS.model_to_load)):
            selected_modelpath = FLAGS.train_dir + "/step-a.model.ckpt.epoch-" + str(FLAGS.model_to_load)
        else:
            print("Model not found in checkpoint folder.")
            exit(0)
        # Reload saved model and test
        init_epoch = FLAGS.model_to_load + 1
        print("Reading model parameters from %s" % selected_modelpath)
        model.saver.restore(sess, selected_modelpath)
        print("Model loaded.")

    # Initialize word embedding before training
    print("Initialize word embedding vocabulary with pretrained embeddings ...")
    sess.run(model.vocab_embed_variable.assign(word_embedding_array))

    ### STEP A: Start pretraining the policy with supervised labels (simple cross-entropy training)
    for epoch in range(init_epoch, FLAGS.train_epoch_crossentropy + 1):
        ep_time = time.time()  # to check duration
        print("STEP A: Epoch " + str(epoch) + " : Start pretraining with supervised labels")
        print("STEP A: Epoch " + str(epoch) + " : Reshuffle training document indices")
        train_data.shuffle_fileindices()

        # Start Batch Training
        step = 1
        total_ce_loss = 0
        counter = 0
        while (step * FLAGS.batch_size) <= len(train_data.fileindices):
            # Get batch data as Numpy Arrays
            batch = train_data.get_batch((step - 1) * FLAGS.batch_size, step * FLAGS.batch_size)

            # Run optimizer: optimize policy and reward estimator
            _, batch_logits, ce_loss, merged_summ = sess.run(
                [model.train_op_policynet_withgold, model.logits, model.cross_entropy_loss, model.merged],
                feed_dict={model.document_placeholder: batch.docs,
                           model.label_placeholder: batch.labels,
                           model.weight_placeholder: batch.weights})
            total_ce_loss += ce_loss

            # Print the progress
            if (step % FLAGS.training_checkpoint) == 0:
                probs = sess.run(model.predictions, feed_dict={model.logits_placeholder: batch_logits})
                training_acc = accuracy_qas_top(probs, batch.labels, batch.weights)
                acc_sum = sess.run(model.tstepa_accuracy_summary,
                                   feed_dict={model.train_acc_placeholder: training_acc})
                total_ce_loss /= FLAGS.training_checkpoint
                print("STEP A: Epoch " + str(epoch) + " : Covered " + str(step * FLAGS.batch_size) + "/" +
                      str(len(train_data.fileindices)) +
                      " : Minibatch CE Loss= {:.6f}".format(total_ce_loss) +
                      ", Minibatch Accuracy= {:.6f}".format(training_acc))
                total_ce_loss = 0

                # Print Summary to Tensor Board
                model.summary_writer.add_summary(
                    merged_summ, (epoch - 1) * len(train_data.fileindices) + step * FLAGS.batch_size)
                model.summary_writer.add_summary(
                    acc_sum, (epoch - 1) * len(train_data.fileindices) + step * FLAGS.batch_size)

            # Increase step
            step += 1
            # if step == 100:
            #     break
        # END-WHILE-TRAINING

        # Save Model
        print("STEP A: Epoch " + str(epoch) + " : Saving model after epoch completion")
        checkpoint_path = os.path.join(FLAGS.train_dir, "step-a.model.ckpt.epoch-" + str(epoch))
        model.saver.save(sess, checkpoint_path)

        # Performance on the validation set
        print("STEP A: Epoch " + str(epoch) + " : Performance on the validation data")
        # Get Predictions: Prohibit the use of gold labels
        FLAGS.authorise_gold_label = False
        FLAGS.use_dropout = False
        val_batch = batch_predict_with_a_model(val_batch, "validation", model, session=sess)
        FLAGS.use_dropout = True

        # Validation Accuracy and Prediction
        probs = sess.run(model.predictions, feed_dict={model.logits_placeholder: val_batch.logits})
        validation_acc = accuracy_qas_top(probs, val_batch.labels, val_batch.weights, val_batch.isf_score_ids)
        ce_loss_val, ce_loss_sum, acc_sum = sess.run(
            [model.cross_entropy_loss_val, model.ce_loss_summary_val, model.vstepa_accuracy_summary],
            feed_dict={model.logits_placeholder: val_batch.logits,
                       model.label_placeholder: val_batch.labels,
                       model.weight_placeholder: val_batch.weights,
                       model.val_acc_placeholder: validation_acc})

        # Print Validation Summary
        model.summary_writer.add_summary(acc_sum, epoch * len(train_data.fileindices))
        model.summary_writer.add_summary(ce_loss_sum, epoch * len(train_data.fileindices))
        print("STEP A: Epoch %s : Validation (%s) CE loss = %.6f" %
              (str(epoch), str(val_batch.docs.shape[0]), ce_loss_val))
        print("STEP A: Epoch %s : Validation (%s) Accuracy= %.6f" %
              (str(epoch), str(val_batch.docs.shape[0]), validation_acc))

        # Estimate MRR on the validation set
        mrr_score = mrr_metric(probs, val_batch.labels, val_batch.weights, val_batch.isf_score_ids, "validation")
        print("STEP A: Epoch %s : Validation (%s) MRR= %.6f" %
              (str(epoch), str(val_batch.docs.shape[0]), mrr_score))

        # Estimate MAP score on the validation set
        mapsc = map_score(probs, val_batch.labels, val_batch.weights, val_batch.isf_score_ids, "validation")
        print("STEP A: Epoch %s : Validation (%s) MAP= %.6f" %
              (str(epoch), str(val_batch.docs.shape[0]), mapsc))

        fn = "%s/step-a.model.ckpt.validation-metrics" % (FLAGS.train_dir)
        save_metrics(fn, FLAGS.load_prediction, validation_acc, mrr_score, mapsc)

        print("STEP A: Epoch %d : Duration: %.4f" % (epoch, time.time() - ep_time))
    # END-FOR-EPOCH

    print("Optimization Finished!")
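# --- Sentence-embedding size correction: worked example (illustrative) -----------------
# When convolution filter outputs are concatenated, sentembed_size is rounded down to a
# multiple of the number of filter lengths (see the correction block near the top of
# train(), train_simple(), and train_debug()). The helper name and example numbers below
# are only for illustration, not values used in the repo.
def _corrected_sentembed_size(sentembed_size, min_filter_length, max_filter_length):
    fil_lens_to_test = max_filter_length - min_filter_length + 1
    if sentembed_size % fil_lens_to_test != 0:
        sentembed_size = (sentembed_size // fil_lens_to_test) * fil_lens_to_test
    return sentembed_size

# e.g. _corrected_sentembed_size(350, 1, 8): 8 filter lengths, 350 % 8 != 0, so 344 is used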
best_wer = 200
best_cer = 200
best_ep = 0
best_ep_wer = 0

with tf.Graph().as_default(), tf.device('/gpu:' + FLAGS.gpu_id):
    config = tf.ConfigProto(allow_soft_placement=True)
    tf.set_random_seed(seed)
    with tf.Session(config=config) as sess:
        model = MY_Model(sess, len(vocab_dict))
        init_epoch = 1
        for epoch in range(init_epoch, FLAGS.train_epoch + 1):
            ep_time = time.time()  # to check duration
            # With SortaGrad, keep the first epoch in sorted order and shuffle afterwards
            shuffle = (epoch != init_epoch) if FLAGS.do_sortagrad else True
            if shuffle:
                train_data.shuffle_fileindices()
            total_loss = 0

            # Start Batch Training
            step = 1
            while (step * FLAGS.batch_size) <= len(train_data.fileindices):
                # Get batch data as Numpy Arrays
                batch = train_data.get_batch((step - 1) * FLAGS.batch_size, step * FLAGS.batch_size)
                dshape = np.array([batch.spect.shape[0], max_fixed], dtype=np.int64)
                sparse_labels = tf.SparseTensorValue(batch.label_indices, batch.label_values, dshape)

                # Run optimizer: optimize the main network under the CTC loss
                _, ctc_loss = sess.run(
                    [model.train_main_network, model.ctc_loss],
                    feed_dict={model.spect_placeholder: batch.spect,
                               model.label_placeholder: sparse_labels,
def train_debug(sess):
    """ Training Mode: Create a new model and train the network """
    ### Prepare data for training
    print("Prepare vocab dict and read pretrained word embeddings ...")
    vocab_dict, word_embedding_array = DataProcessor().prepare_vocab_embeddingdict()
    # vocab_dict contains _PAD and _UNK, but word_embedding_array does not
    print("Prepare training data ...")
    train_data = DataProcessor().prepare_news_data(vocab_dict, data_type="training")
    print("Training size: ", len(train_data.fileindices))
    print("Prepare validation data ...")
    # Validation data is loaded as one whole batch with padded matrices
    val_batch = batch_load_data(DataProcessor().prepare_news_data(vocab_dict,
                                                                  data_type="validation",
                                                                  normalizer=train_data.normalizer,
                                                                  pca_model=train_data.pca_model))
    print("Validation size: ", val_batch.docs.shape[0])
    # print("Prepare ROUGE reward generator ...")
    # rouge_generator = Reward_Generator()

    # Create Model with various operations
    model = MY_Model(sess, len(vocab_dict) - 2)

    init_epoch = 1
    # Resume training if indicated: select the model to load

    # Initialize word embedding before training
    print("Initialize word embedding vocabulary with pretrained embeddings ...")
    sess.run(model.vocab_embed_variable.assign(word_embedding_array))

    ### STEP A: Start pretraining the policy with supervised labels (simple cross-entropy training)
    counter = 0
    for epoch in range(init_epoch, FLAGS.train_epoch_crossentropy + 1):
        ep_time = time.time()  # to check duration
        train_data.shuffle_fileindices()

        # Start Batch Training
        step = 1
        total_ce_loss = 0
        total_train_acc = 0
        while (step * FLAGS.batch_size) <= len(train_data.fileindices):
            # Get batch data as Numpy Arrays
            batch = train_data.get_batch((step - 1) * FLAGS.batch_size, step * FLAGS.batch_size)

            # Run optimizer: optimize policy and reward estimator
            _, batch_logits, ce_loss, merged_summ = sess.run(
                [model.train_op_policynet_withgold, model.logits, model.cross_entropy_loss, model.merged],
                feed_dict={model.document_placeholder: batch.docs,
                           model.label_placeholder: batch.labels,
                           model.weight_placeholder: batch.weights,
                           # model.cnt_placeholder: batch.cnt_score,
                           model.isf_placeholder: batch.isf_score,
                           model.idf_placeholder: batch.idf_score,
                           model.locisf_score_placeholder: batch.locisf_score
                           # model.sent_len_placeholder: batch.sent_lens
                           })
            total_ce_loss += ce_loss
            probs = sess.run(model.predictions, feed_dict={model.logits_placeholder: batch_logits})
            training_acc = accuracy_qas_top(probs, batch.labels, batch.weights, batch.isf_score_ids)
            total_train_acc += training_acc

            # Print the progress
            if (step % FLAGS.training_checkpoint) == 0:
                total_train_acc /= FLAGS.training_checkpoint
                acc_sum = sess.run(model.tstepa_accuracy_summary,
                                   feed_dict={model.train_acc_placeholder: training_acc})
                total_ce_loss /= FLAGS.training_checkpoint

                # Print Summary to Tensor Board
                model.summary_writer.add_summary(merged_summ, counter)
                model.summary_writer.add_summary(acc_sum, counter)

                # Performance on the validation set
                # Get Predictions: Prohibit the use of gold labels
                FLAGS.authorise_gold_label = False
                FLAGS.use_dropout = False
                val_batch = batch_predict_with_a_model(val_batch, "validation", model, session=sess)
                FLAGS.authorise_gold_label = True
                FLAGS.use_dropout = True

                # Validation Accuracy and Prediction
                probs = sess.run(model.probs, feed_dict={model.logits_placeholder: val_batch.logits})
                validation_acc = accuracy_qas_top(probs, val_batch.labels, val_batch.weights,
                                                  val_batch.isf_score_ids)
                ce_loss_val, ce_loss_sum, acc_sum = sess.run(
                    [model.cross_entropy_loss_val, model.ce_loss_summary_val, model.vstepa_accuracy_summary],
                    feed_dict={model.logits_placeholder: val_batch.logits,
                               model.label_placeholder: val_batch.labels,
                               model.weight_placeholder: val_batch.weights,
                               model.isf_score_placeholder: val_batch.isf_score,
                               model.locisf_score_placeholder: val_batch.locisf_score,
                               model.val_acc_placeholder: validation_acc})

                # Print Validation Summary
                model.summary_writer.add_summary(acc_sum, counter)
                model.summary_writer.add_summary(ce_loss_sum, counter)

                print("Epoch %2d, step: %2d(%2d) || CE loss || Train : %4.3f , Val : %4.3f ||| ACC || Train : %.3f , Val : %.3f" %
                      (epoch, step, counter, total_ce_loss, ce_loss_val, total_train_acc, validation_acc))
                total_ce_loss = 0
                total_train_acc = 0

            if (step % 5) == 0:  # to have comparable tensorboard plots
                counter += 1
            # Increase step
            step += 1
        # END-WHILE-TRAINING

        print("Epoch %2d : Duration: %.4f" % (epoch, time.time() - ep_time))

        if not FLAGS.use_subsampled_dataset:
            print("Saving model after epoch completion")
            checkpoint_path = os.path.join(FLAGS.train_dir, "step-a.model.ckpt.epoch-" + str(epoch))
            model.saver.save(sess, checkpoint_path)
    # END-FOR-EPOCH

    print("Optimization Finished!")
def train():
    """ Training Mode: Create a new model and train the network """
    # Training: use the tf default graph
    with tf.Graph().as_default(), tf.device('/gpu:2'):
        config = tf.ConfigProto(allow_soft_placement=True)

        # Start a session
        with tf.Session(config=config) as sess:
            ### Prepare data for training
            print("Prepare vocab dict and read pretrained word embeddings ...")
            vocab_dict, word_embedding_array = DataProcessor().prepare_vocab_embeddingdict()
            # vocab_dict contains _PAD and _UNK, but word_embedding_array does not
            print("Prepare training data ...")
            train_data = DataProcessor().prepare_news_data(vocab_dict, data_type="training")
            print("Prepare validation data ...")
            validation_data = DataProcessor().prepare_news_data(vocab_dict, data_type="validation")
            print("Prepare ROUGE reward generator ...")
            rouge_generator = Reward_Generator()

            # Create Model with various operations
            model = MY_Model(sess, len(vocab_dict) - 2)
            # model = MY_Model(sess, 100)

            # Initialize word embedding before training
            print("Initialize word embedding vocabulary with pretrained embeddings ...")
            sess.run(model.vocab_embed_variable.assign(word_embedding_array))

            ### STEP A: Start pretraining the policy with supervised labels (simple cross-entropy training)
            validation_epochvsrougescores = []
            for epoch in range(1, FLAGS.train_epoch_crossentropy + 1):
                print("STEP A: Epoch " + str(epoch) + " : Start pretraining with supervised labels")
                print("STEP A: Epoch " + str(epoch) + " : Reshuffle training document indices")
                train_data.shuffle_fileindices()

                # Start Batch Training
                step = 1
                while (step * FLAGS.batch_size) <= len(train_data.fileindices):
                    # Get batch data as Numpy Arrays
                    batch_docnames, batch_docs, batch_label, batch_weight = train_data.get_batch(
                        (step - 1) * FLAGS.batch_size, step * FLAGS.batch_size)
                    # print(batch_docnames, batch_label)

                    # Run optimizer: optimize policy and reward estimator
                    sess.run([model.train_op_policynet_withgold],
                             feed_dict={model.document_placeholder: batch_docs,
                                        model.label_placeholder: batch_label,
                                        model.weight_placeholder: batch_weight})

                    # Print the progress
                    if (step % FLAGS.training_checkpoint) == 0:
                        ce_loss_val, acc_val, ce_loss_sum, acc_sum = sess.run(
                            [model.cross_entropy_loss, model.accuracy,
                             model.ce_loss_summary, model.tstepa_accuracy_summary],
                            feed_dict={model.document_placeholder: batch_docs,
                                       model.label_placeholder: batch_label,
                                       model.weight_placeholder: batch_weight})
                        print("STEP A: Epoch " + str(epoch) + " : Covered " + str(step * FLAGS.batch_size) + "/" +
                              str(len(train_data.fileindices)) +
                              " : Minibatch CE Loss= {:.6f}".format(ce_loss_val) +
                              ", Minibatch Accuracy= {:.6f}".format(acc_val))
                        # Print Summary to Tensor Board
                        model.summary_writer.add_summary(
                            ce_loss_sum, (epoch - 1) * len(train_data.fileindices) + step * FLAGS.batch_size)
                        model.summary_writer.add_summary(
                            acc_sum, (epoch - 1) * len(train_data.fileindices) + step * FLAGS.batch_size)

                    # Increase step
                    step += 1
                    # if step == 100:
                    #     break

                # Save Model
                print("STEP A: Epoch " + str(epoch) + " : Saving model after epoch completion")
                checkpoint_path = os.path.join(FLAGS.train_dir, "step-a.model.ckpt.epoch-" + str(epoch))
                model.saver.save(sess, checkpoint_path)

                # Performance on the validation set
                print("STEP A: Epoch " + str(epoch) + " : Performance on the validation data")
                # Get Predictions: Prohibit the use of gold labels
                FLAGS.authorise_gold_label = False
                validation_logits, validation_labels, validation_weights = batch_predict_with_a_model(
                    validation_data, model, session=sess)
                FLAGS.authorise_gold_label = True

                # Validation Accuracy and Prediction
                validation_acc, validation_sum = sess.run(
                    [model.final_accuracy, model.vstepa_accuracy_summary],
                    feed_dict={model.logits_placeholder: validation_logits.eval(session=sess),
                               model.label_placeholder: validation_labels.eval(session=sess),
                               model.weight_placeholder: validation_weights.eval(session=sess)})
                # Print Validation Summary
                model.summary_writer.add_summary(validation_sum, epoch * len(train_data.fileindices))
                print("STEP A: Epoch " + str(epoch) + " : Validation (" + str(len(validation_data.fileindices)) +
                      ") accuracy= {:.6f}".format(validation_acc))

                # Writing validation predictions and final summaries
                print("STEP A: Epoch " + str(epoch) + " : Writing final validation summaries")
                validation_data.write_prediction_summaries(
                    validation_logits, "step-a.model.ckpt.epoch-" + str(epoch), session=sess)

                # Estimate Rouge Scores
                rouge_score = rouge_generator.get_full_rouge(
                    FLAGS.train_dir + "/step-a.model.ckpt.epoch-" + str(epoch) + ".validation-summary-rankedtop3",
                    "validation")
                print("STEP A: Epoch " + str(epoch) + " : Validation (" + str(len(validation_data.fileindices)) +
                      ") rouge= {:.6f}".format(rouge_score))

                # Store validation rouge scores
                validation_epochvsrougescores.append([rouge_score, epoch])
                # break

            print(validation_epochvsrougescores)

    print("Optimization Finished!")