class Predict: def __init__(self): self.args = None self.textData = None self.model = None self.outFile = None self.sess = None self.saver = None self.model_name = None self.model_path = None self.globalStep = 0 self.summaryDir = None self.summaryWriter = None self.mergedSummary = None @staticmethod def parse_args(args): parser = argparse.ArgumentParser() parser.add_argument('--resultDir', type=str, default='result', help='result directory') # data location dataArgs = parser.add_argument_group('Dataset options') dataArgs.add_argument('--summaryDir', type=str, default='./summaries') dataArgs.add_argument('--dataDir', type=str, default='data', help='dataset directory, save pkl here') dataArgs.add_argument('--datasetName', type=str, default='dataset', help='a TextData object') dataArgs.add_argument('--trainFile', type=str, default='sentences.train') # use val file for generation task dataArgs.add_argument('--valFile', type=str, default='sentences.continuation') dataArgs.add_argument('--testFile', type=str, default='sentences.eval') dataArgs.add_argument('--embedFile', type=str, default='wordembeddings-dim100.word2vec') dataArgs.add_argument('--doTest', action='store_true') dataArgs.add_argument('--vocabSize', type=int, default=20000, help='vocab size, use the most frequent words') # neural network options nnArgs = parser.add_argument_group('Network options') nnArgs.add_argument('--embeddingSize', type=int, default=100) nnArgs.add_argument('--hiddenSize', type=int, default=512, help='hiddenSize for RNN sentence encoder') nnArgs.add_argument('--oriSize', type=int, default=512) nnArgs.add_argument('--rnnLayers', type=int, default=1) nnArgs.add_argument('--maxSteps', type=int, default=30) nnArgs.add_argument('--project', action='store_true') # training options trainingArgs = parser.add_argument_group('Training options') trainingArgs.add_argument('--modelPath', type=str, default='saved') trainingArgs.add_argument('--preEmbedding', action='store_true') trainingArgs.add_argument('--dropOut', type=float, default=0.8, help='dropout rate for RNN (keep prob)') trainingArgs.add_argument('--learningRate', type=float, default=0.001, help='learning rate') trainingArgs.add_argument('--batchSize', type=int, default=64, help='batch size') # max_grad_norm trainingArgs.add_argument('--maxGradNorm', type=int, default=5) ## do not add dropOut in the test mode! 
trainingArgs.add_argument('--test', action='store_true', help='if in test mode') trainingArgs.add_argument('--epochs', type=int, default=100, help='most training epochs') trainingArgs.add_argument('--device', type=str, default='/gpu:0', help='use the first GPU as default') trainingArgs.add_argument('--loadModel', action='store_true', help='whether or not to use old models') trainingArgs.add_argument('--testModel', action='store_true', help='do not train, only test') trainingArgs.add_argument('--generate', action='store_true', help='for task 2, generate sentences greedily') # note: we can set this number larger than 20, then we truncated it when handing in trainingArgs.add_argument('--maxGenerateLength', type=int, default=25, help='maximum length when generating sentences') trainingArgs.add_argument('--writePerplexity', action='store_true') return parser.parse_args(args) def main(self, args=None): print('Tensorflow version {}'.format(tf.VERSION)) # initialize args self.args = self.parse_args(args) self.outFile = utils.constructFileName(self.args, prefix=self.args.resultDir) self.args.datasetName = utils.constructFileName(self.args, prefix=self.args.dataDir) datasetFileName = os.path.join(self.args.dataDir, self.args.datasetName) if not os.path.exists(datasetFileName): self.textData = TextData(self.args) with open(datasetFileName, 'wb') as datasetFile: p.dump(self.textData, datasetFile) print('dataset created and saved to {}'.format(datasetFileName)) else: with open(datasetFileName, 'rb') as datasetFile: self.textData = p.load(datasetFile) print('dataset loaded from {}'.format(datasetFileName)) sessConfig = tf.ConfigProto(allow_soft_placement=True) sessConfig.gpu_options.allow_growth = True self.model_path = utils.constructFileName(self.args, prefix=self.args.modelPath, tag='model') self.model_name = os.path.join(self.model_path, 'model') self.sess = tf.Session(config=sessConfig) # summary writer self.summaryDir = utils.constructFileName(self.args, prefix=self.args.summaryDir) with tf.device(self.args.device): if not self.args.generate: self.model = Model(self.args, self.textData) else: print('Creating model for generation') self.model = ModelG(self.args, self.textData) params = tf.trainable_variables() print('Model created') # saver can only be created after we have the model self.saver = tf.train.Saver() self.summaryWriter = tf.summary.FileWriter(self.summaryDir, self.sess.graph) self.mergedSummary = tf.summary.merge_all() if self.args.loadModel: # load model from disk if not os.path.exists(self.model_path): print('model does not exist on disk!') print(self.model_path) exit(-1) self.saver.restore(sess=self.sess, save_path=self.model_name) print('Variables loaded from disk {}'.format(self.model_name)) else: init = tf.global_variables_initializer() # initialize all global variables self.sess.run(init) print('All variables initialized') if not self.args.testModel and not self.args.generate and not self.args.writePerplexity: self.train(self.sess) elif self.args.writePerplexity: self.test(self.sess, tag='test') elif self.args.generate: self.generate(self.sess) else: self.test_model() def generate(self, sess): print('generating!') batches = self.textData.get_batches(tag='val') all_sents = [] for nextBatch in tqdm(batches): # set dummy initial values for predictions predictions = np.zeros(shape=(nextBatch.batch_size), dtype=np.int32) sents = [] for i in range(nextBatch.batch_size): sents.append([self.textData.BOS_WORD]) for time_step in range(self.args.maxGenerateLength): ops, feed_dict, sents = 
self.model.step(nextBatch, predictions=predictions, time_step=time_step, sents=sents) predictions = sess.run(ops, feed_dict) all_sents.extend(sents) with open('generated.txt', 'w') as file: for sent in all_sents: sentence = ' '.join(sent) file.write(sentence+'\n') def train(self, sess): print('Start training') out = open(self.outFile, 'w', 1) out.write(self.outFile + '\n') utils.writeInfo(out, self.args) current_val_loss = np.inf for e in range(self.args.epochs): # training trainBatches = self.textData.get_batches(tag='train') totalTrainLoss = 0.0 # cnt of batches cnt = 0 total_steps = 0 for nextBatch in tqdm(trainBatches): cnt += 1 self.globalStep += 1 for sample in nextBatch.samples: total_steps += sample.length ops, feed_dict = self.model.step(nextBatch, test=False) _, loss, trainPerplexity = sess.run(ops, feed_dict) totalTrainLoss += loss # average across samples in this step trainPerplexity = np.mean(trainPerplexity) self.summaryWriter.add_summary(utils.makeSummary({"trainLoss": loss}), self.globalStep) self.summaryWriter.add_summary(utils.makeSummary({"trainPerplexity": trainPerplexity}), self.globalStep) # compute perplexity over all samples in an epoch trainPerplexity = np.exp(totalTrainLoss/total_steps) print('\nepoch = {}, Train, loss = {}, perplexity = {}'. format(e, totalTrainLoss, trainPerplexity)) out.write('\nepoch = {}, loss = {}, perplexity = {}\n'. format(e, totalTrainLoss, trainPerplexity)) out.flush() valLoss, val_num = self.test(sess, tag='val') testLoss, test_num = self.test(sess, tag='test') valPerplexity = np.exp(valLoss/val_num) testPerplexity = np.exp(testLoss/test_num) print('Val, loss = {}, perplexity = {}'. format(valLoss, valPerplexity)) out.write('Val, loss = {}, perplexity = {}\n'. format(valLoss, valPerplexity)) print('Test, loss = {}, perplexity = {}'. format(testLoss, testPerplexity)) out.write('Test, loss = {}, perplexity = {}\n'. format(testLoss, testPerplexity)) # we do not use cross val currently, just train, then evaluate #if True: if valLoss < current_val_loss: current_val_loss = valLoss print('New val loss {} at epoch {}'.format(valLoss, e)) out.write('New val loss {} at epoch {}\n'.format(valLoss, e)) save_path = self.saver.save(sess, save_path=self.model_name) print('model saved at {}'.format(save_path)) out.write('model saved at {}\n'.format(save_path)) out.flush() out.close() def write_perplexity(self, batch, perplexity): assert len(batch.samples) == len(perplexity) input_ = [] for sample in batch.samples: sent = [] for word_id in sample.input_: word = self.textData.id2word[word_id] sent.append(word) if word == self.textData.EOS_WORD: break sent = ' '.join(sent).strip() input_.append(sent) with open('perplexity.txt', 'a') as file: for idx, sent in enumerate(input_): file.write(sent+'\t'+str(perplexity[idx])+'\n') def test(self, sess, tag = 'val'): if tag == 'val': print('Validating\n') batches = self.textData.val_batches else: print('Testing\n') batches = self.textData.test_batches cnt = 0 total_loss = 0.0 total_steps = 0 for idx, nextBatch in enumerate(tqdm(batches)): cnt += 1 ops, feed_dict = self.model.step(nextBatch, test=True) loss, perplexity = sess.run(ops, feed_dict) total_loss += loss for sample in nextBatch.samples: total_steps += sample.length if self.args.writePerplexity: self.write_perplexity(nextBatch, perplexity) return total_loss, total_steps def test_model(self): # TODO: placeholder, this function is useless in this implementation, ignore pass
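# ---------------------------------------------------------------------------
# Hedged illustration (not part of the original class): train() and test()
# above accumulate the summed per-token cross-entropy loss of an epoch and
# report corpus-level perplexity as exp(total_loss / total_steps), where
# total_steps counts the tokens the loss was summed over. The standalone
# helper below restates that computation; the argument names mirror the
# variables used in train() and are otherwise an assumption.
import numpy as np


def corpus_perplexity(total_loss, total_steps):
    """Perplexity from a summed token-level cross-entropy loss."""
    if total_steps == 0:
        return float('inf')
    return float(np.exp(total_loss / total_steps))


# Example: a summed loss of 4605.17 over 1000 tokens gives
# exp(4.60517) ~= 100, i.e. a perplexity of roughly 100.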
class Predict: def __init__(self): self.args = None self.textData = None self.model = None self.outFile = None self.sess = None self.saver = None self.model_name = None self.model_path = None self.globalStep = 0 self.summaryDir = None self.testOutFile = None self.summaryWriter = None self.mergedSummary = None @staticmethod def parse_args(args): parser = argparse.ArgumentParser() parser.add_argument('--resultDir', type=str, default='result', help='result directory') # data location dataArgs = parser.add_argument_group('Dataset options') dataArgs.add_argument('--random', type=int, default=3) dataArgs.add_argument('--backward', type=int, default=3) dataArgs.add_argument('--near', type=int, default=3) dataArgs.add_argument('--randomFile', type=str, default='random.csv') dataArgs.add_argument('--backwardFile', type=str, default='backward.csv') dataArgs.add_argument('--nearFile', type=str, default='near.csv') dataArgs.add_argument('--summaryDir', type=str, default='summaries') dataArgs.add_argument('--datasetName', type=str, default='dataset', help='a TextData object') dataArgs.add_argument('--dataDir', type=str, default='data', help='dataset directory, save pkl here') dataArgs.add_argument('--trainFile', type=str, default='train_dummy.csv') dataArgs.add_argument('--valFile', type=str, default='val.csv') dataArgs.add_argument('--testFile', type=str, default='test.csv') dataArgs.add_argument('--ethTest', type=str, default='eth_test.csv') dataArgs.add_argument('--embedFile', type=str, default='glove.840B.300d.txt') dataArgs.add_argument('--vocabSize', type=int, default=-1, help='vocab size, use the most frequent words') # neural network options nnArgs = parser.add_argument_group('Network options') nnArgs.add_argument('--embeddingSize', type=int, default=300) nnArgs.add_argument('--nSentences', type=int, default=6) nnArgs.add_argument('--hiddenSize', type=int, default=512, help='hiddenSize for RNN sentence encoder') nnArgs.add_argument('--attSize', type=int, default=512) nnArgs.add_argument('--sentenceAttSize', type=int, default=512) nnArgs.add_argument('--rnnLayers', type=int, default=1) nnArgs.add_argument('--maxSteps', type=int, default=30) nnArgs.add_argument('--numClasses', type=int, default=2) nnArgs.add_argument('--ffnnLayers', type=int, default=2) nnArgs.add_argument('--ffnnSize', type=int, default=300) nnArgs.add_argument('--pffnnLayers', type=int, default=2) nnArgs.add_argument('--pffnnSize', type=int, default=512) nnArgs.add_argument('--nn', type=str, default='att') # training options trainingArgs = parser.add_argument_group('Training options') trainingArgs.add_argument('--dataProcess', action='store_true') trainingArgs.add_argument('--modelPath', type=str, default='saved') trainingArgs.add_argument('--preEmbedding', action='store_true') trainingArgs.add_argument('--dropOut', type=float, default=0.8, help='dropout rate for RNN (keep prob)') trainingArgs.add_argument('--learningRate', type=float, default=0.001, help='learning rate') trainingArgs.add_argument('--batchSize', type=int, default=100, help='batch size') # max_grad_norm ## do not add dropOut in the test mode! 
trainingArgs.add_argument('--twitterTest', action='store_true', help='whether or not do test in twitter dataset') trainingArgs.add_argument('--epochs', type=int, default=200, help='most training epochs') trainingArgs.add_argument('--device', type=str, default='/gpu:0', help='use the first GPU as default') trainingArgs.add_argument('--loadModel', action='store_true', help='whether or not to use old models') trainingArgs.add_argument('--testModel', action='store_true') trainingArgs.add_argument('--testTag', type=str, default='test') return parser.parse_args(args) def main(self, args=None): print('Tensorflow version {}'.format(tf.VERSION)) # initialize args self.args = self.parse_args(args) self.args.resultDir = self.args.nn +'_' + self.args.resultDir self.args.modelPath = self.args.nn +'_' + self.args.modelPath self.args.summaryDir = self.args.nn +'_' + self.args.summaryDir if not os.path.exists(self.args.resultDir): os.makedirs(self.args.resultDir) if not os.path.exists(self.args.modelPath): os.makedirs(self.args.modelPath) if not os.path.exists(self.args.summaryDir): os.makedirs(self.args.summaryDir) self.outFile = utils.constructFileName(self.args, prefix=self.args.resultDir) self.args.datasetName = utils.constructFileName(self.args, prefix=self.args.dataDir) datasetFileName = os.path.join(self.args.dataDir, self.args.datasetName) if not os.path.exists(datasetFileName): self.textData = TextData(self.args) with open(datasetFileName, 'wb') as datasetFile: p.dump(self.textData, datasetFile) print('dataset created and saved to {}'.format(datasetFileName)) else: with open(datasetFileName, 'rb') as datasetFile: self.textData = p.load(datasetFile) print('dataset loaded from {}'.format(datasetFileName)) if self.args.dataProcess: exit(0) sessConfig = tf.ConfigProto(allow_soft_placement=True) sessConfig.gpu_options.allow_growth = True self.model_path = utils.constructFileName(self.args, prefix=self.args.modelPath, tag='model') self.model_name = os.path.join(self.model_path, 'model') self.sess = tf.Session(config=sessConfig) # summary writer self.summaryDir = utils.constructFileName(self.args, prefix=self.args.summaryDir) with tf.device(self.args.device): if self.args.nn == 'vanilla': print('Creating vanilla model!') self.model = Model_vanilla(self.args, self.textData) elif self.args.nn == 'att': print('Creating model with sentences and words attention!') self.model = Model_att(self.args, self.textData) else: print('Creating model with only words attention!') self.model = Model_satt(self.args, self.textData) print('Model created') # saver can only be created after we have the model self.saver = tf.train.Saver() self.summaryWriter = tf.summary.FileWriter(self.summaryDir, self.sess.graph) self.mergedSummary = tf.summary.merge_all() if self.args.loadModel: # load model from disk if not os.path.exists(self.model_path): print('model does not exist on disk!') print(self.model_path) exit(-1) self.saver.restore(sess=self.sess, save_path=self.model_name) print('Variables loaded from disk {}'.format(self.model_name)) else: init = tf.global_variables_initializer() # initialize all global variables self.sess.run(init) print('All variables initialized') if not self.args.testModel: self.train(self.sess) else: self.testModel(sess=self.sess, tag=self.args.testTag) def train(self, sess): print('Start training') out = open(self.outFile, 'w', 1) out.write(self.outFile + '\n') utils.writeInfo(out, self.args) current_valAcc = 0.0 for e in range(self.args.epochs): # training #trainBatches = self.textData.train_batches 
trainBatches = self.textData.get_batches(tag='train') totalTrainLoss = 0.0 # cnt of batches cnt = 0 total_samples = 0 total_corrects = 0 for nextBatch in tqdm(trainBatches): cnt += 1 self.globalStep += 1 total_samples += nextBatch.batch_size ops, feed_dict = self.model.step(nextBatch, test=False) _, loss, predictions, corrects = sess.run(ops, feed_dict) total_corrects += corrects totalTrainLoss += loss # average across samples in this step self.summaryWriter.add_summary(utils.makeSummary({"trainLoss": loss}), self.globalStep) # compute perplexity over all samples in an epoch trainAcc = total_corrects*1.0/total_samples print('\nepoch = {}, Train, loss = {}, trainAcc = {}'. format(e, totalTrainLoss, trainAcc)) out.write('\nepoch = {}, loss = {}, trainAcc = {}\n'. format(e, totalTrainLoss, trainAcc)) out.flush() valAcc, valLoss = self.test(sess, tag='val') print('Val, loss = {}, valAcc = {}'. format(valLoss, valAcc)) out.write('Val, loss = {}, valAcc = {}\n'. format(valLoss, valAcc)) testAcc, testLoss = self.test(sess, tag='test') print('Test, loss = {}, testAcc = {}'. format(testLoss, testAcc)) out.write('Test, loss = {}, testAcc = {}\n'. format(testLoss, testAcc)) out.flush() # we do not use cross val currently, just train, then evaluate if valAcc >= current_valAcc: current_valAcc = valAcc print('New valAcc {} at epoch {}'.format(valAcc, e)) out.write('New valAcc {} at epoch {}\n'.format(valAcc, e)) save_path = self.saver.save(sess, save_path=self.model_name) print('model saved at {}'.format(save_path)) out.write('model saved at {}\n'.format(save_path)) out.flush() out.close() def createETH(self): samples = self.textData._create_samples(os.path.join(self.args.dataDir, self.args.ethTest)) batches = self.textData._create_batch(samples) return batches def testModel(self, sess, tag='test'): if tag == 'test': print('Using original test set to test the performance') out_file_name = 'original_test_results.csv' batches = self.textData.test_batches else: print('Using ETH test set to test the performance') out_file_name = 'ETH_test_results.csv' batches = self.createETH() cnt = 0 total_samples = 0 total_corrects = 0 total_loss = 0.0 all_predictions = [] for idx, nextBatch in enumerate(tqdm(batches)): cnt += 1 total_samples += nextBatch.batch_size ops, feed_dict = self.model.step(nextBatch, test=True) loss, predictions, corrects = sess.run(ops, feed_dict) all_predictions.extend(predictions) total_loss += loss total_corrects += corrects with open(out_file_name, 'w') as file: for prediction in all_predictions: file.write(str(prediction) + '\n') acc = total_corrects*1.0/total_samples print(acc) print('Test Over!') def test(self, sess, tag = 'val'): if tag == 'val': print('Validating\n') batches = self.textData.val_batches else: print('Testing\n') batches = self.textData.test_batches cnt = 0 total_samples = 0 total_corrects = 0 total_loss = 0.0 all_predictions = [] for idx, nextBatch in enumerate(tqdm(batches)): cnt += 1 total_samples += nextBatch.batch_size ops, feed_dict = self.model.step(nextBatch, test=True) loss, predictions, corrects = sess.run(ops, feed_dict) all_predictions.extend(predictions) total_loss += loss total_corrects += corrects acc = total_corrects*1.0/total_samples return acc, total_loss
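# ---------------------------------------------------------------------------
# Hedged sketch of the control flow used by train() above: evaluate on the
# validation set after every epoch and overwrite the checkpoint only when
# validation accuracy improves (valAcc >= current_valAcc). The callables
# run_epoch and evaluate are placeholders standing in for the batch loop and
# self.test(); saver, sess and model_name mirror the members used above.
def fit(epochs, run_epoch, evaluate, saver, sess, model_name):
    best_val_acc = 0.0
    for epoch in range(epochs):
        run_epoch()                          # one pass over the training batches
        val_acc, _ = evaluate(tag='val')     # accuracy on the validation set
        if val_acc >= best_val_acc:          # keep only the best checkpoint
            best_val_acc = val_acc
            saver.save(sess, save_path=model_name)
    return best_val_acc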
def main(self, args=None):
    print('TensorFlow version {}'.format(tf.VERSION))
    # initialize args
    self.args = self.parse_args(args)

    self.resultDir = os.path.join(self.args.resultDir, self.args.dataset)
    self.summaryDir = os.path.join(self.args.summaryDir, self.args.dataset)
    self.dataDir = os.path.join(self.args.dataDir, self.args.dataset)

    self.outFile = utils.constructFileName(self.args, prefix=self.resultDir)
    self.args.datasetName = utils.constructFileName(self.args, prefix=self.args.dataset,
                                                    createDataSetName=True)
    datasetFileName = os.path.join(self.dataDir, self.args.datasetName)

    if not os.path.exists(self.resultDir):
        os.makedirs(self.resultDir)
    if not os.path.exists(self.args.modelPath):
        os.makedirs(self.args.modelPath)
    if not os.path.exists(self.summaryDir):
        os.makedirs(self.summaryDir)

    if not os.path.exists(datasetFileName):
        self.textData = TextData(self.args)
        with open(datasetFileName, 'wb') as datasetFile:
            p.dump(self.textData, datasetFile)
        print('dataset created and saved to {}'.format(datasetFileName))
    else:
        with open(datasetFileName, 'rb') as datasetFile:
            self.textData = p.load(datasetFile)
        print('dataset loaded from {}'.format(datasetFileName))

    self.modelPath = os.path.join(self.args.modelPath, self.args.dataset)
    self.model_path = utils.constructFileName(self.args, prefix=self.modelPath, tag='model')
    if not os.path.exists(self.model_path):
        os.makedirs(self.model_path)
    self.model_name = os.path.join(self.model_path, 'model')

    # summary writer
    self.summaryDir = utils.constructFileName(self.args, prefix=self.summaryDir)

    tf.enable_eager_execution()
    self.model = Model(self.args, self.textData)

    for e in range(self.args.epochs):
        # training
        trainBatches = self.textData.train_batches
        # trainBatches = self.textData.get_batches(tag='train')
        totalTrainLoss = 0.0
        # cnt of batches
        cnt = 0
        total_samples = 0
        total_corrects = 0
        all_skip_rate = []
        for idx, nextBatch in enumerate(tqdm(trainBatches)):
            cnt += 1
            # nextBatch = trainBatches[227]
            self.globalStep += 1
            total_samples += nextBatch.batch_size

            self.model.step(nextBatch, test=False)
            _, loss, predictions, corrects, skip_rate, skip_flag = self.model.buildNetwork()
            loss = self.to_numpy(loss)
            predictions = self.to_numpy(predictions)
            corrects = self.to_numpy(corrects)
            skip_rate = self.to_numpy(skip_rate)
            skip_flag = self.to_numpy(skip_flag)

            # skip_rate: batch_size * n_samples
            all_skip_rate.extend(skip_rate.tolist())
            # print(loss, idx)
            total_corrects += corrects
            totalTrainLoss += loss

        trainAcc = total_corrects * 1.0 / (total_samples * self.args.nSamples)
        train_skip_rate = np.average(all_skip_rate)
        print('\nepoch = {}, Train, loss = {}, trainAcc = {}, train_skip_rate = {}'.
              format(e, totalTrainLoss, trainAcc, train_skip_rate))
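# ---------------------------------------------------------------------------
# Hedged sketch: the eager-mode loop above converts the tensors returned by
# buildNetwork() with self.to_numpy(...), a helper that is not included in
# this excerpt. A plausible minimal version simply calls .numpy() on eager
# tensors and passes other values through unchanged; this is an assumption
# about the missing helper, not the original code.
def to_numpy(value):
    """Convert an eager tf.Tensor to a numpy value; leave other inputs untouched."""
    return value.numpy() if hasattr(value, 'numpy') else value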
class Train: def __init__(self): self.args = None self.textData = None self.model = None self.outFile = None self.sess = None self.saver = None self.model_name = None self.model_path = None self.globalStep = 0 self.summaryDir = None self.testOutFile = None self.summaryWriter = None self.mergedSummary = None @staticmethod def parse_args(args): parser = argparse.ArgumentParser() parser.add_argument('--resultDir', type=str, default='result', help='result directory') parser.add_argument('--testDir', type=str, default='test_result') # data location dataArgs = parser.add_argument_group('Dataset options') dataArgs.add_argument('--summaryDir', type=str, default='summaries') dataArgs.add_argument('--datasetName', type=str, default='dataset', help='a TextData object') dataArgs.add_argument('--dataDir', type=str, default='data', help='dataset directory, save pkl here') dataArgs.add_argument('--dataset', type=str, default='emo') dataArgs.add_argument('--trainFile', type=str, default='train.txt') dataArgs.add_argument('--valFile', type=str, default='dev.txt') dataArgs.add_argument('--testFile', type=str, default='test.txt') dataArgs.add_argument('--trainLiwcFile', type=str, default='train_liwc.csv') dataArgs.add_argument('--valLiwcFile', type=str, default='dev_liwc.csv') dataArgs.add_argument('--testLiwcFile', type=str, default='dev_liwc.csv') dataArgs.add_argument('--embeddingFile', type=str, default='glove.840B.300d.txt') dataArgs.add_argument('--vocabSize', type=int, default=-1, help='vocab size, use the most frequent words') dataArgs.add_argument('--snliDir', type=str, default='snli') dataArgs.add_argument('--trainSnliFile', type=str, default='train_snli.txt') dataArgs.add_argument('--valSnliFile', type=str, default='dev_snli.txt') dataArgs.add_argument('--testSnliFile', type=str, default='test_snli.txt') # neural network options nnArgs = parser.add_argument_group('Network options') nnArgs.add_argument('--embeddingSize', type=int, default=300) nnArgs.add_argument('--hiddenSize', type=int, default=300, help='hiddenSize for RNN sentence encoder') nnArgs.add_argument( '--rnnLayers', type=int, default=1, help='number of RNN layers, fix to 1 in the DCRNN model') nnArgs.add_argument('--maxSteps', type=int, default=30) nnArgs.add_argument('--emoClasses', type=int, default=4) nnArgs.add_argument('--snliClasses', type=int, default=2) nnArgs.add_argument('--nTurn', type=int, default=3) nnArgs.add_argument('--speakerEmbedSize', type=int, default=0) nnArgs.add_argument('--nLSTM', type=int, default=3, help='in DCRNN, this is the ') nnArgs.add_argument('--heads', type=int, default=3) nnArgs.add_argument('--attn', action='store_true') nnArgs.add_argument('--selfattn', action='store_true') nnArgs.add_argument('--nContexts', type=int, default=4) nnArgs.add_argument('--independent', action='store_true') # training options trainingArgs = parser.add_argument_group('Training options') trainingArgs.add_argument( '--model', type=str, help='hbmp, dcrnn, transfer+hbmp, transfer+dcrnn, hbmp+share') trainingArgs.add_argument('--eager', action='store_true', help='turn on eager mode for debugging') trainingArgs.add_argument('--modelPath', type=str, default='saved') trainingArgs.add_argument('--preEmbedding', action='store_true') trainingArgs.add_argument('--elmo', action='store_true') trainingArgs.add_argument('--trainElmo', action='store_true') trainingArgs.add_argument('--dropOut', type=float, default=1.0, help='dropout rate for RNN (keep prob)') trainingArgs.add_argument('--learningRate', type=float, default=0.001, 
help='learning rate') trainingArgs.add_argument('--batchSize', type=int, default=100, help='batch size') trainingArgs.add_argument('--epochs', type=int, default=200, help='most training epochs') trainingArgs.add_argument('--device', type=str, default='/gpu:0', help='use the first GPU as default') trainingArgs.add_argument('--loadModel', action='store_true', help='whether or not to use old models') trainingArgs.add_argument( '--sampleWeight', default=6.848, type=float, help='a constant to balance different categories') trainingArgs.add_argument( '--weighted', action='store_true', help='whether or not to weight the training samples') trainingArgs.add_argument('--ffnn', type=int, default=500, help='intermediate ffnn size') trainingArgs.add_argument( '--gamma', type=float, default=0.1, help='we use a lambda to balance between snli and emo,' 'multiply gradients snli samples by this lambda') trainingArgs.add_argument( '--universal', action='store_true', help='whether or not to use universal sent embeddings') trainingArgs.add_argument( '--augment', action='store_true', help='data augumentation by adding random 2nd sentences') """ in training data: happy, sad, angry: 5k (16.67%) each; others: 15k (50%) in dev/test: happy, sad, angry: 4% each; others: 88% 88/3 = 29.333 29.33/4 = 7.33 in genuine data: 30160 train, happy = 4243, of 0.14068302387267906, sad = 5463, of 0.1811339522546419, angry = 5506, of 0.18255968169761272, others = 14948, of 0.4956233421750663 2755 val, happy = 142, of 0.051542649727767696, sad = 125, of 0.045372050816696916, angry = 150, of 0.0544464609800363, others = 2338, of 0.8486388384754991 """ return parser.parse_args(args) def statistics(self): """ 27144 train, happy = 3815, of 0.14054671382257589, sad = 4920, of 0.1812555260831123, angry = 4977, of 0.1833554376657825, others = 13432, of 0.4948423224285293 3016 val, happy = 428, of 0.1419098143236074, sad = 543, of 0.18003978779840848, angry = 529, of 0.17539787798408488, others = 1516, of 0.5026525198938993 :return: """ train_samples = self.textData.train_samples val_samples = self.textData.valid_samples def cnt(samples, tag='train'): happy = 0 sad = 0 angry = 0 others = 0 for sample in samples: if sample.label == self.textData.label2idx['happy']: happy += 1 elif sample.label == self.textData.label2idx['sad']: sad += 1 elif sample.label == self.textData.label2idx['angry']: angry += 1 else: others += 1 total = happy + sad + angry + others print( '{} {}, happy = {}, of {}, sad = {}, of {}, angry = {}, of {}, others = {}, of {}' .format(total, tag, happy, happy / total, sad, sad / total, angry, angry / total, others, others / total)) cnt(train_samples, 'train') cnt(val_samples, 'val') exit(0) def main(self, args=None): print('TensorFlow version {}'.format(tf.VERSION)) # initialize args self.args = self.parse_args(args) self.resultDir = os.path.join(self.args.resultDir, self.args.dataset) self.summaryDir = os.path.join(self.args.summaryDir, self.args.dataset) self.dataDir = os.path.join(self.args.dataDir, self.args.dataset) self.testDir = os.path.join(self.args.testDir, self.args.dataset) self.outFile = utils.constructFileName(self.args, prefix=self.resultDir) self.testFile = utils.constructFileName(self.args, prefix=self.testDir) self.args.datasetName = utils.constructFileName( self.args, prefix=self.args.dataset, createDataSetName=True) datasetFileName = os.path.join(self.dataDir, self.args.datasetName) if not os.path.exists(self.resultDir): os.makedirs(self.resultDir) if not os.path.exists(self.testDir): 
os.makedirs(self.testDir) if not os.path.exists(self.args.modelPath): os.makedirs(self.args.modelPath) if not os.path.exists(self.summaryDir): os.makedirs(self.summaryDir) if not os.path.exists(datasetFileName): self.textData = TextData(self.args) with open(datasetFileName, 'wb') as datasetFile: p.dump(self.textData, datasetFile) print('dataset created and saved to {}, exiting ...'.format( datasetFileName)) exit(0) else: with open(datasetFileName, 'rb') as datasetFile: self.textData = p.load(datasetFile) print('dataset loaded from {}'.format(datasetFileName)) # self.statistics() sessConfig = tf.ConfigProto(allow_soft_placement=True) sessConfig.gpu_options.allow_growth = True self.model_path = os.path.join(self.args.modelPath, self.args.dataset) self.model_path = utils.constructFileName(self.args, prefix=self.model_path, tag='model') if not os.path.exists(self.model_path): os.makedirs(self.model_path) self.model_name = os.path.join(self.model_path, 'model') self.sess = tf.Session(config=sessConfig) # summary writer self.summaryDir = utils.constructFileName(self.args, prefix=self.summaryDir) if self.args.eager: tf.enable_eager_execution( config=sessConfig, device_policy=tf.contrib.eager.DEVICE_PLACEMENT_WARN) print('eager execution enabled') # import timeit # # start = timeit.default_timer() # # self.textData.get_batches(tag='train', augment=self.args.augment) # # stop = timeit.default_timer() # # print('Time: ', stop - start) # exit(0) with tf.device(self.args.device): if self.args.model.find('hbmp') != -1: if self.args.model.find('share') != -1: print('Creating model with HBMP share') self.model = ModelHBMPShare(self.args, self.textData, eager=self.args.eager) else: print('Creating model with HBMP ordinary') self.model = ModelHBMP(self.args, self.textData, eager=self.args.eager) else: self.model = ModelBasic(self.args, self.textData, eager=self.args.eager) print('Basic model created!') if self.args.eager: self.train_eager() exit(0) # saver can only be created after we have the model self.saver = tf.train.Saver() self.summaryWriter = tf.summary.FileWriter(self.summaryDir, self.sess.graph) self.mergedSummary = tf.summary.merge_all() if self.args.loadModel: # load model from disk if not os.path.exists(self.model_path): print('model does not exist on disk!') print(self.model_path) exit(-1) self.saver.restore(sess=self.sess, save_path=self.model_name) print('Variables loaded from disk {}'.format(self.model_name)) else: init = tf.global_variables_initializer() table_init = tf.tables_initializer() # initialize all global variables self.sess.run([init, table_init]) print('All variables initialized') self.train(self.sess) def train_eager(self): for e in range(self.args.epochs): trainBatches = self.textData.train_batches for idx, nextBatch in enumerate(tqdm(trainBatches)): self.model.step(nextBatch, test=False, eager=self.args.eager) self.model.buildNetwork() print() def train(self, sess): #sess = tf_debug.LocalCLIDebugWrapperSession(sess) print('Start training') out = open(self.outFile, 'w', 1) out.write(self.outFile + '\n') utils.writeInfo(out, self.args) current_val_f1_micro_unweighted = 0.0 for e in range(self.args.epochs): # training trainBatches = self.textData.get_batches(tag='train', augment=self.args.augment) #trainBatches = self.textData.train_batches totalTrainLoss = 0.0 # cnt of batches cnt = 0 total_samples = 0 total_corrects = 0 all_predictions = [] all_labels = [] all_sample_weights = [] for idx, nextBatch in enumerate(tqdm(trainBatches)): cnt += 1 self.globalStep += 1 total_samples 
+= nextBatch.batch_size # print(idx) ops, feed_dict, labels, sample_weights = self.model.step( nextBatch, test=False) _, loss, predictions, corrects = sess.run(ops, feed_dict) all_predictions.extend(predictions) all_labels.extend(labels) all_sample_weights.extend(sample_weights) total_corrects += corrects totalTrainLoss += loss self.summaryWriter.add_summary( utils.makeSummary({"train_loss": loss}), self.globalStep) #break trainAcc = total_corrects * 1.0 / total_samples # calculate f1 score for train (weighted/unweighted) train_f1_micro, train_f1_macro, train_p_micro, train_r_micro, train_p_macro, train_r_macro\ = self.cal_F1(y_pred=all_predictions, y_true=all_labels) train_f1_micro_w, train_f1_macro_w, train_p_micro_w, train_r_micro_w, train_p_macro_w, train_r_macro_w\ = self.cal_F1(y_pred=all_predictions, y_true=all_labels, sample_weight=all_sample_weights) print( '\nepoch = {}, Train, loss = {}, trainAcc = {}, train_f1_micro = {}, train_f1_macro = {},' ' train_f1_micro_w = {}, train_f1_macro_w = {}'.format( e, totalTrainLoss, trainAcc, train_f1_micro, train_f1_macro, train_f1_micro_w, train_f1_macro_w)) print( '\ttrain_p_micro = {}, train_r_micro = {}, train_p_macro = {}, train_r_macro = {}' .format(train_p_micro, train_r_micro, train_p_macro, train_r_macro)) print( '\ttrain_p_micro_w = {}, train_r_micro_w = {}, train_p_macro_w = {}, train_r_macro_w = {}' .format(train_p_micro_w, train_r_micro_w, train_p_macro_w, train_r_macro_w)) out.write( '\nepoch = {}, loss = {}, trainAcc = {}, train_f1_micro = {}, train_f1_macro = {},' ' train_f1_micro_w = {}, train_f1_macro_w = {}\n'.format( e, totalTrainLoss, trainAcc, train_f1_micro, train_f1_macro, train_f1_micro_w, train_f1_macro_w)) out.write( '\ttrain_p_micro = {}, train_r_micro = {}, train_p_macro = {}, train_r_macro = {}\n' .format(train_p_micro, train_r_micro, train_p_macro, train_r_macro)) out.write( '\ttrain_p_micro_w = {}, train_r_micro_w = {}, train_p_macro_w = {}, train_r_macro_w = {}\n' .format(train_p_micro_w, train_r_micro_w, train_p_macro_w, train_r_macro_w)) #continue out.flush() # calculate f1 score for val (weighted/unweighted) valAcc, valLoss, val_f1_micro, val_f1_macro, val_f1_micro_w, val_f1_macro_w,\ val_p_micro, val_r_micro, val_p_macro, val_r_macro, val_p_micro_w, val_r_micro_w, val_p_macro_w, val_r_macro_w\ = self.test(sess, tag='val') print( '\n\tVal, loss = {}, valAcc = {}, val_f1_micro = {}, val_f1_macro = {}, val_f1_micro_w = {}, val_f1_macro_w = {}' .format(valLoss, valAcc, val_f1_micro, val_f1_macro, val_f1_micro_w, val_f1_macro_w)) print( '\t\t val_p_micro = {}, val_r_micro = {}, val_p_macro = {}, val_r_macro = {}' .format(val_p_micro, val_r_micro, val_p_macro, val_r_macro)) print( '\t\t val_p_micro_w = {}, val_r_micro_w = {}, val_p_macro_w = {}, val_r_macro_w = {}' .format(val_p_micro_w, val_r_micro_w, val_p_macro_w, val_r_macro_w)) out.write( '\n\tVal, loss = {}, valAcc = {}, val_f1_micro = {}, val_f1_macro = {}, val_f1_micro_w = {}, val_f1_macro_w = {}\n' .format(valLoss, valAcc, val_f1_micro, val_f1_macro, val_f1_micro_w, val_f1_macro_w)) out.write( '\t\t val_p_micro = {}, val_r_micro = {}, val_p_macro = {}, val_r_macro = {}\n' .format(val_p_micro, val_r_micro, val_p_macro, val_r_macro)) out.write( '\t\t val_p_micro_w = {}, val_r_micro_w = {}, val_p_macro_w = {}, val_r_macro_w = {}\n' .format(val_p_micro_w, val_r_micro_w, val_p_macro_w, val_r_macro_w)) # calculate f1 score for test (weighted/unweighted) testAcc, testLoss, test_f1_micro, test_f1_macro, test_f1_micro_w, test_f1_macro_w,\ test_p_micro, 
test_r_micro, test_p_macro, test_r_macro, test_p_micro_w, test_r_micro_w, test_p_macro_w, test_r_macro_w\ = self.test(sess, tag='test') print( '\n\ttest, loss = {}, testAcc = {}, test_f1_micro = {}, test_f1_macro = {}, test_f1_micro_w = {}, test_f1_macro_w = {}' .format(testLoss, testAcc, test_f1_micro, test_f1_macro, test_f1_micro_w, test_f1_macro_w)) print( '\t\t test_p_micro = {}, test_r_micro = {}, test_p_macro = {}, test_r_macro = {}' .format(test_p_micro, test_r_micro, test_p_macro, test_r_macro)) print( '\t\t test_p_micro_w = {}, test_r_micro_w = {}, test_p_macro_w = {}, test_r_macro_w = {}' .format(test_p_micro_w, test_r_micro_w, test_p_macro_w, test_r_macro_w)) out.write( '\n\ttest, loss = {}, testAcc = {}, test_f1_micro = {}, test_f1_macro = {}, test_f1_micro_w = {}, test_f1_macro_w = {}\n' .format(testLoss, testAcc, test_f1_micro, test_f1_macro, test_f1_micro_w, test_f1_macro_w)) out.write( '\t\t test_p_micro = {}, test_r_micro = {}, test_p_macro = {}, test_r_macro = {}\n' .format(test_p_micro, test_r_micro, test_p_macro, test_r_macro)) out.write( '\t\t test_p_micro_w = {}, test_r_micro_w = {}, test_p_macro_w = {}, test_r_macro_w = {}\n' .format(test_p_micro_w, test_r_micro_w, test_p_macro_w, test_r_macro_w)) out.flush() self.summaryWriter.add_summary( utils.makeSummary({"train_acc": trainAcc}), e) self.summaryWriter.add_summary( utils.makeSummary({"val_acc": valAcc}), e) # use val_f1_micro (unweighted) as metric if val_f1_micro >= current_val_f1_micro_unweighted: current_val_f1_micro_unweighted = val_f1_micro print('New val f1_micro {} at epoch {}'.format( val_f1_micro, e)) out.write('New val f1_micro {} at epoch {}\n'.format( val_f1_micro, e)) save_path = self.saver.save(sess, save_path=self.model_name) print('model saved at {}'.format(save_path)) out.write('model saved at {}\n'.format(save_path)) test_predictions = self.test(sess, tag='test2') print('Writing predictions at epoch {}'.format(e)) out.write('Writing predictions at epoch {}\n'.format(e)) test_file = self.write_predictions(test_predictions, tag='unweighted') print('Writing predictions to {}'.format(test_file)) out.write('Writing predictions to {}\n'.format(test_file)) out.flush() out.close() def write_predictions(self, predictions, tag='weighted'): test_file = self.testFile + '_' + tag with open(test_file, 'w') as file: file.write('id\tturn1\tturn2\tturn3\tlabel\n') idx2label = {v: k for k, v in self.textData.label2idx.items()} for idx, sample in enumerate(self.textData.test_samples): assert idx == sample.id file.write(str(idx) + '\t') for ind, sent in enumerate(sample.sents): file.write(' '.join(sent[:sample.length[ind]]).encode( 'ascii', 'ignore').decode('ascii') + '\t') file.write(idx2label[predictions[idx]] + '\n') return test_file def test(self, sess, tag='val'): """ for the real dev data, during test, do not use sample weights :param sess: :param tag: :return: """ if tag == 'val': print('Validating\n') batches = self.textData.val_batches else: print('Testing\n') batches = self.textData.test_batches cnt = 0 total_samples = 0 total_corrects = 0 total_loss = 0.0 all_predictions = [] all_labels = [] all_sample_weights = [] for idx, nextBatch in enumerate(tqdm(batches)): cnt += 1 total_samples += nextBatch.batch_size ops, feed_dict, labels, sample_weights = self.model.step(nextBatch, test=True) loss, predictions, corrects = sess.run(ops, feed_dict) all_predictions.extend(predictions) all_labels.extend(labels) all_sample_weights.extend(sample_weights) total_loss += loss total_corrects += corrects #break f1_micro, 
f1_macro, p_micro, r_micro, p_macro, r_macro = self.cal_F1( y_pred=all_predictions, y_true=all_labels) f1_micro_w, f1_macro_w, p_micro_w, r_micro_w, p_macro_w, r_macro_w =\ self.cal_F1(y_pred=all_predictions, y_true=all_labels, sample_weight=all_sample_weights) acc = total_corrects * 1.0 / total_samples if tag == 'test2': return all_predictions else: return acc, total_loss, f1_micro, f1_macro, f1_micro_w, f1_macro_w,\ p_micro, r_micro, p_macro, r_macro,\ p_micro_w, r_micro_w, p_macro_w, r_macro_w def cal_F1(self, y_pred, y_true, sample_weight=None): labels = [ self.textData.label2idx['happy'], self.textData.label2idx['sad'], self.textData.label2idx['angry'] ] # if sample_weight is not None: # sample_weight = np.asarray(sample_weight) # sample_weight = sample_weight.astype(int) p_micro, r_micro, f1_micro, _ = \ precision_recall_fscore_support(y_true=y_true, y_pred=y_pred, average='micro', labels=labels, sample_weight=sample_weight) p_macro, r_macro, f1_macro, _ = \ precision_recall_fscore_support(y_true=y_true, y_pred=y_pred, average='macro', labels=labels, sample_weight=sample_weight) return f1_micro, f1_macro, p_micro, r_micro, p_macro, r_macro
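# ---------------------------------------------------------------------------
# Hedged usage example for cal_F1() above: micro- and macro-averaged scores
# come from sklearn's precision_recall_fscore_support, restricted to the
# three emotion classes (happy/sad/angry) so the dominant "others" class does
# not inflate the numbers. The integer label ids below are made up for the
# example; in the class they come from self.textData.label2idx.
from sklearn.metrics import precision_recall_fscore_support

y_true = [0, 1, 2, 3, 1, 0]      # 0=happy, 1=sad, 2=angry, 3=others (assumed ids)
y_pred = [0, 1, 3, 3, 1, 2]
emotion_labels = [0, 1, 2]       # exclude "others", as cal_F1 does

p_micro, r_micro, f1_micro, _ = precision_recall_fscore_support(
    y_true=y_true, y_pred=y_pred, average='micro', labels=emotion_labels)
p_macro, r_macro, f1_macro, _ = precision_recall_fscore_support(
    y_true=y_true, y_pred=y_pred, average='macro', labels=emotion_labels)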
class Train: def __init__(self): self.args = None self.textData = None self.model = None self.outFile = None self.sess = None self.saver = None self.model_name = None self.model_path = None self.globalStep = 0 self.summaryDir = None self.testOutFile = None self.summaryWriter = None self.mergedSummary = None @staticmethod def parse_args(args): parser = argparse.ArgumentParser() parser.add_argument('--resultDir', type=str, default='result', help='result directory') parser.add_argument('--testDir', type=str, default='test_result') # data location dataArgs = parser.add_argument_group('Dataset options') dataArgs.add_argument('--summaryDir', type=str, default='summaries') dataArgs.add_argument('--datasetName', type=str, default='dataset', help='a TextData object') dataArgs.add_argument('--dataDir', type=str, default='data', help='dataset directory, save pkl here') dataArgs.add_argument('--dataset', type=str, default='rotten') dataArgs.add_argument('--trainFile', type=str, default='train.txt') dataArgs.add_argument('--valFile', type=str, default='val.txt') dataArgs.add_argument('--testFile', type=str, default='test.txt') dataArgs.add_argument('--embeddingFile', type=str, default='glove.840B.300d.txt') dataArgs.add_argument('--vocabSize', type=int, default=-1, help='vocab size, use the most frequent words') # neural network options nnArgs = parser.add_argument_group('Network options') nnArgs.add_argument('--embeddingSize', type=int, default=300) nnArgs.add_argument('--hiddenSize', type=int, default=200, help='hiddenSize for RNN sentence encoder') nnArgs.add_argument('--rnnLayers', type=int, default=1) nnArgs.add_argument('--maxSteps', type=int, default=50) nnArgs.add_argument('--numClasses', type=int, default=2) nnArgs.add_argument('--skim', action='store_true') # training options trainingArgs = parser.add_argument_group('Training options') trainingArgs.add_argument('--modelPath', type=str, default='saved') trainingArgs.add_argument('--preEmbedding', action='store_true') trainingArgs.add_argument('--dropOut', type=float, default=1.0, help='dropout rate for RNN (keep prob)') trainingArgs.add_argument('--learningRate', type=float, default=0.001, help='learning rate') trainingArgs.add_argument('--batchSize', type=int, default=32, help='batch size') trainingArgs.add_argument('--skimloss', action='store_true', help='whether or not to encourage skimming') trainingArgs.add_argument('--minRead', type=int, default=2) trainingArgs.add_argument('--maxSkip', type=int, default=5) trainingArgs.add_argument('--discount', type=float, default=0.99) # max_grad_norm ## do not add dropOut in the test mode! 
trainingArgs.add_argument('--epochs', type=int, default=40, help='most training epochs') trainingArgs.add_argument('--device', type=str, default='/gpu:0', help='use the first GPU as default') trainingArgs.add_argument('--loadModel', action='store_true', help='whether or not to use old models') trainingArgs.add_argument('--testModel', action='store_true') trainingArgs.add_argument('--printgate', action='store_true') trainingArgs.add_argument('--nSamples', type=int, default=3) trainingArgs.add_argument('--eps', type=float, default=0.1) return parser.parse_args(args) def main(self, args=None): print('TensorFlow version {}'.format(tf.VERSION)) # initialize args self.args = self.parse_args(args) self.resultDir = os.path.join(self.args.resultDir, self.args.dataset) self.summaryDir = os.path.join(self.args.summaryDir, self.args.dataset) self.dataDir = os.path.join(self.args.dataDir, self.args.dataset) self.outFile = utils.constructFileName(self.args, prefix=self.resultDir) self.args.datasetName = utils.constructFileName( self.args, prefix=self.args.dataset, createDataSetName=True) datasetFileName = os.path.join(self.dataDir, self.args.datasetName) if not os.path.exists(self.resultDir): os.makedirs(self.resultDir) if not os.path.exists(self.args.modelPath): os.makedirs(self.args.modelPath) if not os.path.exists(self.summaryDir): os.makedirs(self.summaryDir) if not os.path.exists(datasetFileName): self.textData = TextData(self.args) with open(datasetFileName, 'wb') as datasetFile: p.dump(self.textData, datasetFile) print('dataset created and saved to {}'.format(datasetFileName)) else: with open(datasetFileName, 'rb') as datasetFile: self.textData = p.load(datasetFile) print('dataset loaded from {}'.format(datasetFileName)) sessConfig = tf.ConfigProto(allow_soft_placement=True) sessConfig.gpu_options.allow_growth = True self.modelPath = os.path.join(self.args.modelPath, self.args.dataset) self.model_path = utils.constructFileName(self.args, prefix=self.modelPath, tag='model') if not os.path.exists(self.model_path): os.makedirs(self.model_path) self.model_name = os.path.join(self.model_path, 'model') self.sess = tf.Session(config=sessConfig) # summary writer self.summaryDir = utils.constructFileName(self.args, prefix=self.summaryDir) with tf.device(self.args.device): if self.args.skim: print('Skim model created') self.model = Model(self.args, self.textData) else: print('Ordinary model created!') self.model = ModelBasic(self.args, self.textData) # saver can only be created after we have the model self.saver = tf.train.Saver() self.summaryWriter = tf.summary.FileWriter(self.summaryDir, self.sess.graph) self.mergedSummary = tf.summary.merge_all() if self.args.loadModel: # load model from disk if not os.path.exists(self.model_path): print('model does not exist on disk!') print(self.model_path) exit(-1) self.saver.restore(sess=self.sess, save_path=self.model_name) print('Variables loaded from disk {}'.format(self.model_name)) else: init = tf.global_variables_initializer() # initialize all global variables self.sess.run(init) print('All variables initialized') if not self.args.testModel: self.train(self.sess) else: self.testModel(self.sess) def train(self, sess): #sess = tf_debug.LocalCLIDebugWrapperSession(sess) print('Start training') out = open(self.outFile, 'w', 1) out.write(self.outFile + '\n') utils.writeInfo(out, self.args) current_valAcc = 0.0 for e in range(self.args.epochs): # training #trainBatches = self.textData.train_batches trainBatches = self.textData.get_batches(tag='train') totalTrainLoss 
= 0.0 # cnt of batches cnt = 0 total_samples = 0 total_corrects = 0 all_skip_rate = [] for idx, nextBatch in enumerate(tqdm(trainBatches)): cnt += 1 #nextBatch = trainBatches[227] self.globalStep += 1 total_samples += nextBatch.batch_size py = tf.contrib.eager.py_func(self.model.buildNetwork()) ops, feed_dict, length = self.model.step(nextBatch, test=False) sess.run(py, feed_dict=feed_dict) # skip_rate: batch_size * n_samples _, loss, predictions, corrects, skip_rate = sess.run( ops, feed_dict) all_skip_rate.extend(skip_rate.tolist()) #print(loss, idx) total_corrects += corrects totalTrainLoss += loss self.summaryWriter.add_summary( utils.makeSummary({"train_loss": loss}), self.globalStep) trainAcc = total_corrects * 1.0 / (total_samples * self.args.nSamples) train_skip_rate = np.average(all_skip_rate) print( '\nepoch = {}, Train, loss = {}, trainAcc = {}, train_skip_rate = {}' .format(e, totalTrainLoss, trainAcc, train_skip_rate)) #continue out.write( '\nepoch = {}, loss = {}, trainAcc = {}, train_skip_rate = {}\n' .format(e, totalTrainLoss, trainAcc, train_skip_rate)) out.flush() valAcc, valLoss, val_skip_rate = self.test(sess, tag='val') testAcc, testLoss, test_skip_rate = self.test(sess, tag='test') print('\tVal, loss = {}, valAcc = {}, val_skip_rate = {}'.format( valLoss, valAcc, val_skip_rate)) out.write( '\tVal, loss = {}, valAcc = {}, val_skip_rate = {}\n'.format( valLoss, valAcc, val_skip_rate)) print( '\tTest, loss = {}, testAcc = {}, test_skip_rate = {}'.format( testLoss, testAcc, test_skip_rate)) out.write('\tTest, loss = {}, testAcc = {}, test_skip_rate = {}\n'. format(testLoss, testAcc, test_skip_rate)) self.summaryWriter.add_summary( utils.makeSummary({"train_acc": trainAcc}), e) self.summaryWriter.add_summary( utils.makeSummary({"val_acc": valAcc}), e) self.summaryWriter.add_summary( utils.makeSummary({"test_acc": testAcc}), e) self.summaryWriter.add_summary( utils.makeSummary({"train_skip_rate": train_skip_rate}), e) self.summaryWriter.add_summary( utils.makeSummary({"val_skip_rate": val_skip_rate}), e) self.summaryWriter.add_summary( utils.makeSummary({"test_skip_rate": test_skip_rate}), e) # we do not use cross val currently, just train, then evaluate if valAcc >= current_valAcc: # with open('skip_train.pkl', 'wb') as f: # p.dump(all_skip_rate, f) current_valAcc = valAcc print('New valAcc {} at epoch {}'.format(valAcc, e)) out.write('New valAcc {} at epoch {}\n'.format(valAcc, e)) save_path = self.saver.save(sess, save_path=self.model_name) print('model saved at {}'.format(save_path)) out.write('model saved at {}\n'.format(save_path)) out.flush() out.close() def test(self, sess, tag='val'): if tag == 'val': print('Validating\n') batches = self.textData.val_batches else: print('Testing\n') batches = self.textData.test_batches cnt = 0 total_samples = 0 total_corrects = 0 total_loss = 0.0 all_predictions = [] all_skip_rate = [] for idx, nextBatch in enumerate(tqdm(batches)): cnt += 1 total_samples += nextBatch.batch_size ops, feed_dict, length = self.model.step(nextBatch, test=True) loss, predictions, corrects, skip_rate = sess.run(ops, feed_dict) all_skip_rate.extend(skip_rate) all_predictions.extend(predictions) total_loss += loss total_corrects += corrects total_length = np.sum(length) # plt.hist(all_skip_rate) # plt.savefig('tmp.png') # print(np.average(all_skip_rate)) acc = total_corrects * 1.0 / total_samples return acc, total_loss, np.average(all_skip_rate) def testModel(self, sess): acc, total_loss, _ = self.test(sess, tag='test') print('acc = {}, total_loss = 
{}'.format(acc, total_loss))
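# ---------------------------------------------------------------------------
# Hedged sketch: the training loops above log plain Python scalars with
# self.summaryWriter.add_summary(utils.makeSummary({...}), step). The body of
# utils.makeSummary is not part of this excerpt; a minimal TF 1.x version
# builds the tf.Summary proto by hand, as below. This is an assumption about
# that helper, not the original utils code.
import tensorflow as tf


def makeSummary(value_dict):
    """Wrap a {tag: scalar} dict in a tf.Summary proto for FileWriter.add_summary."""
    return tf.Summary(value=[
        tf.Summary.Value(tag=name, simple_value=float(val))
        for name, val in value_dict.items()
    ])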
def main(self, args=None): print('TensorFlow version {}'.format(tf.VERSION)) # initialize args self.args = self.parse_args(args) self.process_gates_files() self.resultDir = os.path.join(self.args.resultDir, self.args.dataset) self.summaryDir = os.path.join(self.args.summaryDir, self.args.dataset) self.dataDir = os.path.join(self.args.dataDir, self.args.dataset) self.outFile = utils.constructFileName(self.args, prefix=self.resultDir) self.args.datasetName = utils.constructFileName( self.args, prefix=self.args.dataset, createDataSetName=True) if self.args.next: self.args.datasetName = 'next_' + self.args.datasetName else: self.args.datasetName = 'current_' + self.args.datasetName datasetFileName = os.path.join(self.dataDir, self.args.datasetName) if not os.path.exists(self.resultDir): os.makedirs(self.resultDir) if not os.path.exists(self.args.modelPath): os.makedirs(self.args.modelPath) if not os.path.exists(self.summaryDir): os.makedirs(self.summaryDir) if not os.path.exists(datasetFileName): self.textData = TextData(self.args) with open(datasetFileName, 'wb') as datasetFile: p.dump(self.textData, datasetFile) print('dataset created and saved to {}'.format(datasetFileName)) else: with open(datasetFileName, 'rb') as datasetFile: self.textData = p.load(datasetFile) print('dataset loaded from {}'.format(datasetFileName)) # init hard gates self.init_hard_gates(self.args.threshold, self.args.percent) sessConfig = tf.ConfigProto(allow_soft_placement=True) sessConfig.gpu_options.allow_growth = True self.modelPath = os.path.join(self.args.modelPath, self.args.dataset) self.model_path = utils.constructFileName(self.args, prefix=self.modelPath, tag='model') if not os.path.exists(self.model_path): os.makedirs(self.model_path) self.model_name = os.path.join(self.model_path, 'model') self.sess = tf.Session(config=sessConfig) # summary writer self.summaryDir = utils.constructFileName(self.args, prefix=self.summaryDir) with tf.device(self.args.device): if self.args.skim: print('Skim model created') self.model = Model(self.args, self.textData) else: print('Ordinary model created!') self.model = ModelBasic(self.args, self.textData) # saver can only be created after we have the model self.saver = tf.train.Saver() self.summaryWriter = tf.summary.FileWriter(self.summaryDir, self.sess.graph) self.mergedSummary = tf.summary.merge_all() if self.args.loadModel: # load model from disk if not os.path.exists(self.model_path): print('model does not exist on disk!') print(self.model_path) exit(-1) self.saver.restore(sess=self.sess, save_path=self.model_name) print('Variables loaded from disk {}'.format(self.model_name)) else: init = tf.global_variables_initializer() # initialize all global variables self.sess.run(init) print('All variables initialized') if not self.args.testModel: self.train(self.sess) else: self.testModel(self.sess)
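# ---------------------------------------------------------------------------
# Hedged sketch: every main() above derives its result/model/summary paths
# with utils.constructFileName(args, prefix=..., tag=...), whose body is not
# shown in this excerpt. The stand-in below illustrates the idea of encoding
# a few hyper-parameters into the path; the chosen fields are an assumption,
# and the real helper also supports options such as createDataSetName.
import os


def constructFileName(args, prefix, tag=''):
    parts = ([tag] if tag else []) + [
        'bs{}'.format(args.batchSize),
        'hs{}'.format(args.hiddenSize),
        'lr{}'.format(args.learningRate),
    ]
    return os.path.join(prefix, '_'.join(parts))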
class Train: def __init__(self): self.args = None self.textData = None self.model = None self.outFile = None self.sess = None self.saver = None self.model_name = None self.model_path = None self.globalStep = 0 self.summaryDir = None self.testOutFile = None self.summaryWriter = None self.mergedSummary = None @staticmethod def parse_args(args): parser = argparse.ArgumentParser() parser.add_argument('--resultDir', type=str, default='result', help='result directory') parser.add_argument('--testDir', type=str, default='test_result') # data location dataArgs = parser.add_argument_group('Dataset options') dataArgs.add_argument('--summaryDir', type=str, default='summaries') dataArgs.add_argument('--datasetName', type=str, default='dataset', help='a TextData object') dataArgs.add_argument('--dataDir', type=str, default='data', help='dataset directory, save pkl here') dataArgs.add_argument('--dataset', type=str, default='rotten') dataArgs.add_argument('--trainFile', type=str, default='train.txt') dataArgs.add_argument('--valFile', type=str, default='val.txt') dataArgs.add_argument('--testFile', type=str, default='test.txt') dataArgs.add_argument('--embeddingFile', type=str, default='glove.840B.300d.txt') dataArgs.add_argument('--vocabSize', type=int, default=-1, help='vocab size, use the most frequent words') dataArgs.add_argument('--gatesFileTrain', type=str, default='gates_train_') dataArgs.add_argument('--gatesFileVal', type=str, default='gates_val_') dataArgs.add_argument('--gatesFileTest', type=str, default='gates_test_') # neural network options nnArgs = parser.add_argument_group('Network options') nnArgs.add_argument('--embeddingSize', type=int, default=300) nnArgs.add_argument('--hiddenSize', type=int, default=200, help='hiddenSize for RNN sentence encoder') nnArgs.add_argument('--rnnLayers', type=int, default=1) nnArgs.add_argument('--maxSteps', type=int, default=50) nnArgs.add_argument('--numClasses', type=int, default=2) nnArgs.add_argument('--skim', action='store_true') # training options trainingArgs = parser.add_argument_group('Training options') trainingArgs.add_argument('--modelPath', type=str, default='saved') trainingArgs.add_argument('--preEmbedding', action='store_true') trainingArgs.add_argument('--dropOut', type=float, default=1.0, help='dropout keep probability for the RNN') trainingArgs.add_argument('--learningRate', type=float, default=0.001, help='learning rate') trainingArgs.add_argument('--batchSize', type=int, default=32, help='batch size') trainingArgs.add_argument('--skimloss', action='store_true', help='whether or not to encourage skimming') trainingArgs.add_argument( '--minRead', type=int, default=8, help='minimum number of tokens read before a jump') trainingArgs.add_argument('--maxSkip', type=int, default=5, help='maximum number of steps skipped in a single jump') trainingArgs.add_argument( '--maxJump', type=int, default=-1, help= 'maximum number of jumps in a sequence; -1 means the ACL-style jump limit is not used' ) trainingArgs.add_argument('--discount', type=float, default=0.99) trainingArgs.add_argument('--epochs', type=int, default=300, help='maximum number of training epochs') trainingArgs.add_argument('--device', type=str, default='/gpu:0', help='use the first GPU as default') trainingArgs.add_argument('--loadModel', action='store_true', help='whether or not to load a previously saved model') trainingArgs.add_argument('--testModel', action='store_true') trainingArgs.add_argument('--printgate', action='store_true') trainingArgs.add_argument('--nSamples', type=int, default=3) trainingArgs.add_argument('--eps',
type=float, default=0.1) trainingArgs.add_argument('--random', action='store_true') trainingArgs.add_argument('--sparse', type=float, default=10.0, help='coefficient for the sparsity penalty') trainingArgs.add_argument( '--percent', action='store_true', help='whether or not to use a percentage for hard-gate init') trainingArgs.add_argument('--threshold', type=float, default=0.5, help='threshold for hard-gate init') trainingArgs.add_argument('--transferEpochs', type=int, default=150, help='number of epochs for transferring') trainingArgs.add_argument('--next', action='store_true', help='use next-style gates instead of current-style gates') trainingArgs.add_argument( '--all', action='store_true', help='whether or not to use transfer mode for val & test as well') return parser.parse_args(args) def process_gates_files(self): if self.args.next: self.args.gatesFileTrain = self.args.gatesFileTrain + 'next.csv' self.args.gatesFileVal = self.args.gatesFileVal + 'next.csv' self.args.gatesFileTest = self.args.gatesFileTest + 'next.csv' else: self.args.gatesFileTrain = self.args.gatesFileTrain + 'current.csv' self.args.gatesFileVal = self.args.gatesFileVal + 'current.csv' self.args.gatesFileTest = self.args.gatesFileTest + 'current.csv' def main(self, args=None): print('TensorFlow version {}'.format(tf.VERSION)) # initialize args self.args = self.parse_args(args) self.process_gates_files() self.resultDir = os.path.join(self.args.resultDir, self.args.dataset) self.summaryDir = os.path.join(self.args.summaryDir, self.args.dataset) self.dataDir = os.path.join(self.args.dataDir, self.args.dataset) self.outFile = utils.constructFileName(self.args, prefix=self.resultDir) self.args.datasetName = utils.constructFileName( self.args, prefix=self.args.dataset, createDataSetName=True) if self.args.next: self.args.datasetName = 'next_' + self.args.datasetName else: self.args.datasetName = 'current_' + self.args.datasetName datasetFileName = os.path.join(self.dataDir, self.args.datasetName) if not os.path.exists(self.resultDir): os.makedirs(self.resultDir) if not os.path.exists(self.args.modelPath): os.makedirs(self.args.modelPath) if not os.path.exists(self.summaryDir): os.makedirs(self.summaryDir) if not os.path.exists(datasetFileName): self.textData = TextData(self.args) with open(datasetFileName, 'wb') as datasetFile: p.dump(self.textData, datasetFile) print('dataset created and saved to {}'.format(datasetFileName)) else: with open(datasetFileName, 'rb') as datasetFile: self.textData = p.load(datasetFile) print('dataset loaded from {}'.format(datasetFileName)) # init hard gates self.init_hard_gates(self.args.threshold, self.args.percent) sessConfig = tf.ConfigProto(allow_soft_placement=True) sessConfig.gpu_options.allow_growth = True self.modelPath = os.path.join(self.args.modelPath, self.args.dataset) self.model_path = utils.constructFileName(self.args, prefix=self.modelPath, tag='model') if not os.path.exists(self.model_path): os.makedirs(self.model_path) self.model_name = os.path.join(self.model_path, 'model') self.sess = tf.Session(config=sessConfig) # summary writer self.summaryDir = utils.constructFileName(self.args, prefix=self.summaryDir) with tf.device(self.args.device): if self.args.skim: print('Skim model created') self.model = Model(self.args, self.textData) else: print('Ordinary model created!') self.model = ModelBasic(self.args, self.textData) # saver can only be created after we have the model self.saver = tf.train.Saver() self.summaryWriter = tf.summary.FileWriter(self.summaryDir, self.sess.graph) self.mergedSummary = tf.summary.merge_all() if
self.args.loadModel: # load model from disk if not os.path.exists(self.model_path): print('model does not exist on disk!') print(self.model_path) exit(-1) self.saver.restore(sess=self.sess, save_path=self.model_name) print('Variables loaded from disk {}'.format(self.model_name)) else: init = tf.global_variables_initializer() # initialize all global variables self.sess.run(init) print('All variables initialized') if not self.args.testModel: self.train(self.sess) else: self.testModel(self.sess) def train(self, sess): #sess = tf_debug.LocalCLIDebugWrapperSession(sess) print('Start training') out = open(self.outFile, 'w', 1) out.write(self.outFile + '\n') utils.writeInfo(out, self.args) current_valAcc = 0.0 for e in range(self.args.epochs): # training #trainBatches = self.textData.train_batches trainBatches = self.textData.get_batches(tag='train') totalTrainLoss = 0.0 # cnt of batches cnt = 0 total_samples = 0 total_corrects = 0 all_skip_rate = [] all_correct_predicted_inference_skips = [] total_valids = 0 if e == self.args.transferEpochs + 1: print('RL begins!') out.write('RL begins\n') for idx, nextBatch in enumerate(tqdm(trainBatches)): cnt += 1 #nextBatch = trainBatches[227] self.globalStep += 1 total_samples += nextBatch.batch_size if e > self.args.transferEpochs + 1: ops, feed_dict, length = self.model.step( nextBatch, test=False, is_transfering=False) else: # only transfer at first xxx epochs ops, feed_dict, length = self.model.step( nextBatch, test=False, is_transfering=True) # skip_rate: batch_size * n_samples _, loss, predictions, corrects, skip_rate, correct_predicted_inference_skips, n_valids_sum = sess.run( ops, feed_dict) all_correct_predicted_inference_skips.extend( correct_predicted_inference_skips.tolist()) all_skip_rate.extend(skip_rate.tolist()) total_valids += n_valids_sum #print(loss, idx) total_corrects += corrects totalTrainLoss += loss self.summaryWriter.add_summary( utils.makeSummary({"train_loss": loss}), self.globalStep) trainAcc = total_corrects * 1.0 / (total_samples * self.args.nSamples) train_skip_rate = np.average(all_skip_rate) train_skip_acc = np.sum( all_correct_predicted_inference_skips) / total_valids print( '\nepoch = {}, Train, loss = {}, trainAcc = {}, train_skip_rate = {}, train_skip_acc = {}' .format(e, totalTrainLoss, trainAcc, train_skip_rate, train_skip_acc)) out.write( '\nepoch = {}, loss = {}, trainAcc = {}, train_skip_rate = {}, train_skip_acc = {}\n' .format(e, totalTrainLoss, trainAcc, train_skip_rate, train_skip_acc)) out.flush() valAcc, valLoss, val_skip_rate, val_skip_acc = self.test(sess, tag='val') testAcc, testLoss, test_skip_rate, test_skip_acc = self.test( sess, tag='test') print( '\tVal, loss = {}, valAcc = {}, val_skip_rate = {}, val_skip_acc = {}' .format(valLoss, valAcc, val_skip_rate, val_skip_acc)) out.write( '\tVal, loss = {}, valAcc = {}, val_skip_rate = {}, val_skip_acc = {}\n' .format(valLoss, valAcc, val_skip_rate, val_skip_acc)) print( '\tTest, loss = {}, testAcc = {}, test_skip_rate = {}, test_skip_acc = {}' .format(testLoss, testAcc, test_skip_rate, test_skip_acc)) out.write( '\tTest, loss = {}, testAcc = {}, test_skip_rate = {}, test_skip_acc = {}\n' .format(testLoss, testAcc, test_skip_rate, test_skip_acc)) self.summaryWriter.add_summary( utils.makeSummary({"train_acc": trainAcc}), e) self.summaryWriter.add_summary( utils.makeSummary({"val_acc": valAcc}), e) self.summaryWriter.add_summary( utils.makeSummary({"test_acc": testAcc}), e) self.summaryWriter.add_summary( utils.makeSummary({"train_skip_rate": train_skip_rate}), 
e) self.summaryWriter.add_summary( utils.makeSummary({"val_skip_rate": val_skip_rate}), e) self.summaryWriter.add_summary( utils.makeSummary({"test_skip_rate": test_skip_rate}), e) self.summaryWriter.add_summary( utils.makeSummary({"train_skip_acc": train_skip_acc}), e) self.summaryWriter.add_summary( utils.makeSummary({"val_skip_acc": val_skip_acc}), e) self.summaryWriter.add_summary( utils.makeSummary({"test_skip_acc": test_skip_acc}), e) # we do not use cross val currently, just train, then evaluate if e == self.args.transferEpochs + 1: current_valAcc = 0 if valAcc >= current_valAcc: # with open('skip_train.pkl', 'wb') as f: # p.dump(all_skip_rate, f) current_valAcc = valAcc if e < self.args.transferEpochs: print('New valAcc {} at epoch {}'.format(valAcc, e)) out.write('New valAcc {} at epoch {}\n'.format(valAcc, e)) else: print('New valAcc2 {} at epoch {}'.format(valAcc, e)) out.write('New valAcc2 {} at epoch {}\n'.format(valAcc, e)) # save_path = self.saver.save(sess, save_path=self.model_name) # print('model saved at {}'.format(save_path)) # out.write('model saved at {}\n'.format(save_path)) out.flush() out.close() def test(self, sess, tag='val'): if tag == 'val': print('Validating\n') batches = self.textData.val_batches else: print('Testing\n') batches = self.textData.test_batches cnt = 0 total_samples = 0 total_corrects = 0 total_loss = 0.0 all_predictions = [] all_skip_rate = [] total_valids = 0 all_correct_predicted_inference_skips = [] for idx, nextBatch in enumerate(tqdm(batches)): cnt += 1 total_samples += nextBatch.batch_size if self.args.all: is_transfering = True else: is_transfering = False ops, feed_dict, length = self.model.step( nextBatch, test=True, is_transfering=is_transfering) loss, predictions, corrects, skip_rate, correct_predicted_inference_skips, n_valids_sum = sess.run( ops, feed_dict) all_correct_predicted_inference_skips.extend( correct_predicted_inference_skips.tolist()) all_skip_rate.extend(skip_rate.tolist()) total_valids += n_valids_sum all_predictions.extend(predictions) total_loss += loss total_corrects += corrects total_length = np.sum(length) # plt.hist(all_skip_rate) # plt.savefig('tmp.png') # print(np.average(all_skip_rate)) predicted_skip_acc = np.sum( all_correct_predicted_inference_skips) / total_valids acc = total_corrects * 1.0 / total_samples return acc, total_loss, np.average(all_skip_rate), predicted_skip_acc def testModel(self, sess): acc, total_loss, _, _ = self.test(sess, tag='test') print('acc = {}, total_loss = {}'.format(acc, total_loss)) def init_hard_gates(self, threshold, percent): train_batch = self.textData.train_batches for batch in train_batch: batch.init_hard_gates(threshold=threshold, percent=percent, max_skip=self.args.maxSkip) val_batch = self.textData.val_batches for batch in val_batch: batch.init_hard_gates(threshold=threshold, percent=percent, max_skip=self.args.maxSkip) test_batch = self.textData.test_batches for batch in test_batch: batch.init_hard_gates(threshold=threshold, percent=percent, max_skip=self.args.maxSkip)
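# ---------------------------------------------------------------------------
# NOTE (editor): a minimal entry point for the Train class above, assuming this
# file is run directly as a script (the repository may instead provide its own
# launcher; the script name in the example invocation below is hypothetical).
if __name__ == '__main__':
    trainer = Train()
    trainer.main()

# Example invocation (flag names come from parse_args above; the values are
# purely illustrative):
#   python train_transfer.py --dataset rotten --skim --preEmbedding \
#       --transferEpochs 150 --next --batchSize 32 --learningRate 0.001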