Example no. 1
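# Assumed imports for this example (a minimal sketch): the standard-library and
# TensorFlow 1.x imports below match the calls in the code, while the
# project-local module paths (utils, TextData, Model, ModelG) are hypothetical
# placeholders to be adjusted to the actual package layout.
import argparse
import os
import pickle as p

import numpy as np
import tensorflow as tf
from tqdm import tqdm

# hypothetical project-local modules:
# import utils
# from textdata import TextData
# from model import Model, ModelG
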
class Predict:

    def __init__(self):
        self.args = None


        self.textData = None
        self.model = None
        self.outFile = None
        self.sess = None
        self.saver = None
        self.model_name = None
        self.model_path = None
        self.globalStep = 0
        self.summaryDir = None

        self.summaryWriter = None
        self.mergedSummary = None

    @staticmethod
    def parse_args(args):

        parser = argparse.ArgumentParser()

        parser.add_argument('--resultDir', type=str, default='result', help='result directory')
        # data location
        dataArgs = parser.add_argument_group('Dataset options')

        dataArgs.add_argument('--summaryDir', type=str, default='./summaries')
        dataArgs.add_argument('--dataDir', type=str, default='data', help='dataset directory, save pkl here')
        dataArgs.add_argument('--datasetName', type=str, default='dataset', help='a TextData object')
        dataArgs.add_argument('--trainFile', type=str, default='sentences.train')

        # use val file for generation task
        dataArgs.add_argument('--valFile', type=str, default='sentences.continuation')
        dataArgs.add_argument('--testFile', type=str, default='sentences.eval')
        dataArgs.add_argument('--embedFile', type=str, default='wordembeddings-dim100.word2vec')
        dataArgs.add_argument('--doTest', action='store_true')
        dataArgs.add_argument('--vocabSize', type=int, default=20000, help='vocab size, use the most frequent words')
        # neural network options
        nnArgs = parser.add_argument_group('Network options')
        nnArgs.add_argument('--embeddingSize', type=int, default=100)
        nnArgs.add_argument('--hiddenSize', type=int, default=512, help='hiddenSize for RNN sentence encoder')
        nnArgs.add_argument('--oriSize', type=int, default=512)
        nnArgs.add_argument('--rnnLayers', type=int, default=1)
        nnArgs.add_argument('--maxSteps', type=int, default=30)
        nnArgs.add_argument('--project', action='store_true')
        # training options
        trainingArgs = parser.add_argument_group('Training options')
        trainingArgs.add_argument('--modelPath', type=str, default='saved')
        trainingArgs.add_argument('--preEmbedding', action='store_true')
        trainingArgs.add_argument('--dropOut', type=float, default=0.8, help='dropout rate for RNN (keep prob)')
        trainingArgs.add_argument('--learningRate', type=float, default=0.001, help='learning rate')
        trainingArgs.add_argument('--batchSize', type=int, default=64, help='batch size')
        # max_grad_norm
        trainingArgs.add_argument('--maxGradNorm', type=int, default=5)
        ## do not apply dropout in test mode!
        trainingArgs.add_argument('--test', action='store_true', help='if in test mode')
        trainingArgs.add_argument('--epochs', type=int, default=100, help='maximum number of training epochs')
        trainingArgs.add_argument('--device', type=str, default='/gpu:0', help='use the first GPU as default')
        trainingArgs.add_argument('--loadModel', action='store_true', help='whether or not to use old models')
        trainingArgs.add_argument('--testModel', action='store_true', help='do not train, only test')
        trainingArgs.add_argument('--generate', action='store_true', help='for task 2, generate sentences greedily')
        # note: we can set this number larger than 20, then truncate it when handing in
        trainingArgs.add_argument('--maxGenerateLength', type=int, default=25, help='maximum length when generating sentences')
        trainingArgs.add_argument('--writePerplexity', action='store_true')
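        # Example invocation (illustrative only; the script name and flag
        # values are placeholders, not taken from the original project):
        #   python predict.py --dataDir data --batchSize 64 --epochs 100 --preEmbedding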
        return parser.parse_args(args)

    def main(self, args=None):
        print('TensorFlow version {}'.format(tf.VERSION))

        # initialize args
        self.args = self.parse_args(args)


        self.outFile = utils.constructFileName(self.args, prefix=self.args.resultDir)
        self.args.datasetName = utils.constructFileName(self.args, prefix=self.args.dataDir)
        datasetFileName = os.path.join(self.args.dataDir, self.args.datasetName)


        if not os.path.exists(datasetFileName):
            self.textData = TextData(self.args)
            with open(datasetFileName, 'wb') as datasetFile:
                p.dump(self.textData, datasetFile)
            print('dataset created and saved to {}'.format(datasetFileName))
        else:
            with open(datasetFileName, 'rb') as datasetFile:
                self.textData = p.load(datasetFile)
            print('dataset loaded from {}'.format(datasetFileName))

        sessConfig = tf.ConfigProto(allow_soft_placement=True)
        sessConfig.gpu_options.allow_growth = True
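        # (allow_soft_placement lets TF fall back to another device when an op
        #  has no kernel on the requested one; allow_growth keeps TF from
        #  reserving all GPU memory up front)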

        self.model_path = utils.constructFileName(self.args, prefix=self.args.modelPath, tag='model')
        self.model_name = os.path.join(self.model_path, 'model')

        self.sess = tf.Session(config=sessConfig)
        # summary writer
        self.summaryDir = utils.constructFileName(self.args, prefix=self.args.summaryDir)


        with tf.device(self.args.device):
            if not self.args.generate:
                self.model = Model(self.args, self.textData)
            else:
                print('Creating model for generation')
                self.model = ModelG(self.args, self.textData)
            params = tf.trainable_variables()
            print('Model created')

            # saver can only be created after we have the model
            self.saver = tf.train.Saver()

            self.summaryWriter = tf.summary.FileWriter(self.summaryDir, self.sess.graph)
            self.mergedSummary = tf.summary.merge_all()

            if self.args.loadModel:
                # load model from disk
                if not os.path.exists(self.model_path):
                    print('model does not exist on disk!')
                    print(self.model_path)
                    exit(-1)

                self.saver.restore(sess=self.sess, save_path=self.model_name)
                print('Variables loaded from disk {}'.format(self.model_name))
            else:
                init = tf.global_variables_initializer()
                # initialize all global variables
                self.sess.run(init)
                print('All variables initialized')

            if not self.args.testModel and not self.args.generate and not self.args.writePerplexity:
                self.train(self.sess)
            elif self.args.writePerplexity:
                self.test(self.sess, tag='test')
            elif self.args.generate:
                self.generate(self.sess)
            else:
                self.test_model()



    def generate(self, sess):
        print('generating!')
        batches = self.textData.get_batches(tag='val')
        all_sents = []
        for nextBatch in tqdm(batches):
            # set dummy initial values for predictions
            predictions = np.zeros(shape=(nextBatch.batch_size), dtype=np.int32)
            sents = []
            for i in range(nextBatch.batch_size):
                sents.append([self.textData.BOS_WORD])

            for time_step in range(self.args.maxGenerateLength):
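                # Greedy decoding sketch: at each step the model consumes the
                # previous predictions and returns the next argmax word ids,
                # which are fed back in at the following step (the exact
                # contract of model.step is project-specific).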
                ops, feed_dict, sents = self.model.step(nextBatch, predictions=predictions, time_step=time_step, sents=sents)
                predictions = sess.run(ops, feed_dict)

            all_sents.extend(sents)

        with open('generated.txt', 'w') as file:
            for sent in all_sents:
                sentence = ' '.join(sent)
                file.write(sentence+'\n')


    def train(self, sess):
        print('Start training')

        out = open(self.outFile, 'w', 1)
        out.write(self.outFile + '\n')
        utils.writeInfo(out, self.args)

        current_val_loss = np.inf

        for e in range(self.args.epochs):
            # training
            trainBatches = self.textData.get_batches(tag='train')
            totalTrainLoss = 0.0

            # cnt of batches
            cnt = 0

            total_steps = 0
            for nextBatch in tqdm(trainBatches):
                cnt += 1
                self.globalStep += 1

                for sample in nextBatch.samples:
                    total_steps += sample.length

                ops, feed_dict = self.model.step(nextBatch, test=False)

                _, loss, trainPerplexity = sess.run(ops, feed_dict)

                totalTrainLoss += loss

                # average across samples in this step
                trainPerplexity = np.mean(trainPerplexity)
                self.summaryWriter.add_summary(utils.makeSummary({"trainLoss": loss}), self.globalStep)
                self.summaryWriter.add_summary(utils.makeSummary({"trainPerplexity": trainPerplexity}), self.globalStep)

            # compute perplexity over all samples in an epoch
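            # (this assumes `loss` is the summed token-level negative
            #  log-likelihood per batch, so exp(total_loss / total_tokens) is
            #  the standard per-word perplexity)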
            trainPerplexity = np.exp(totalTrainLoss/total_steps)

            print('\nepoch = {}, Train, loss = {}, perplexity = {}'.
                  format(e, totalTrainLoss, trainPerplexity))
            out.write('\nepoch = {}, loss = {}, perplexity = {}\n'.
                  format(e, totalTrainLoss, trainPerplexity))
            out.flush()

            valLoss, val_num = self.test(sess, tag='val')

            testLoss, test_num = self.test(sess, tag='test')

            valPerplexity = np.exp(valLoss/val_num)
            testPerplexity = np.exp(testLoss/test_num)

            print('Val, loss = {}, perplexity = {}'.
                  format(valLoss, valPerplexity))
            out.write('Val, loss = {}, perplexity = {}\n'.
                  format(valLoss, valPerplexity))

            print('Test, loss = {}, perplexity = {}'.
                  format(testLoss, testPerplexity))
            out.write('Test, loss = {}, perplexity = {}\n'.
                  format(testLoss, testPerplexity))

            # we do not use cross val currently, just train, then evaluate
            #if True:
            if valLoss < current_val_loss:
                current_val_loss = valLoss
                print('New val loss {} at epoch {}'.format(valLoss, e))
                out.write('New val loss {} at epoch {}\n'.format(valLoss, e))
                save_path = self.saver.save(sess, save_path=self.model_name)
                print('model saved at {}'.format(save_path))
                out.write('model saved at {}\n'.format(save_path))

            out.flush()
        out.close()


    def write_perplexity(self, batch, perplexity):
        assert len(batch.samples) == len(perplexity)
        input_ = []
        for sample in batch.samples:
            sent = []
            for word_id in sample.input_:
                word = self.textData.id2word[word_id]
                sent.append(word)
                if word == self.textData.EOS_WORD:
                    break
            sent = ' '.join(sent).strip()
            input_.append(sent)
        with open('perplexity.txt', 'a') as file:
            for idx, sent in enumerate(input_):
                file.write(sent+'\t'+str(perplexity[idx])+'\n')

    def test(self, sess, tag='val'):
        if tag == 'val':
            print('Validating\n')
            batches = self.textData.val_batches
        else:
            print('Testing\n')
            batches = self.textData.test_batches

        cnt = 0

        total_loss = 0.0
        total_steps = 0
        for idx, nextBatch in enumerate(tqdm(batches)):
            cnt += 1
            ops, feed_dict = self.model.step(nextBatch, test=True)
            loss, perplexity = sess.run(ops, feed_dict)

            total_loss += loss
            for sample in nextBatch.samples:
                total_steps += sample.length

            if self.args.writePerplexity:
                self.write_perplexity(nextBatch, perplexity)
        return total_loss, total_steps


    def test_model(self):
        # TODO: placeholder; unused in this implementation
        pass
Example no. 2
    def main(self, args=None):
        print('TensorFlow version {}'.format(tf.VERSION))

        # initialize args
        self.args = self.parse_args(args)


        self.outFile = utils.constructFileName(self.args, prefix=self.args.resultDir)
        self.args.datasetName = utils.constructFileName(self.args, prefix=self.args.dataDir)
        datasetFileName = os.path.join(self.args.dataDir, self.args.datasetName)


        if not os.path.exists(datasetFileName):
            self.textData = TextData(self.args)
            with open(datasetFileName, 'wb') as datasetFile:
                p.dump(self.textData, datasetFile)
            print('dataset created and saved to {}'.format(datasetFileName))
        else:
            with open(datasetFileName, 'rb') as datasetFile:
                self.textData = p.load(datasetFile)
            print('dataset loaded from {}'.format(datasetFileName))

        sessConfig = tf.ConfigProto(allow_soft_placement=True)
        sessConfig.gpu_options.allow_growth = True

        self.model_path = utils.constructFileName(self.args, prefix=self.args.modelPath, tag='model')
        self.model_name = os.path.join(self.model_path, 'model')

        self.sess = tf.Session(config=sessConfig)
        # summary writer
        self.summaryDir = utils.constructFileName(self.args, prefix=self.args.summaryDir)


        with tf.device(self.args.device):
            if not self.args.generate:
                self.model = Model(self.args, self.textData)
            else:
                print('Creating model for generation')
                self.model = ModelG(self.args, self.textData)
            params = tf.trainable_variables()
            print('Model created')

            # saver can only be created after we have the model
            self.saver = tf.train.Saver()

            self.summaryWriter = tf.summary.FileWriter(self.summaryDir, self.sess.graph)
            self.mergedSummary = tf.summary.merge_all()

            if self.args.loadModel:
                # load model from disk
                if not os.path.exists(self.model_path):
                    print('model does not exist on disk!')
                    print(self.model_path)
                    exit(-1)

                self.saver.restore(sess=self.sess, save_path=self.model_name)
                print('Variables loaded from disk {}'.format(self.model_name))
            else:
                init = tf.global_variables_initializer()
                # initialize all global variables
                self.sess.run(init)
                print('All variables initialized')

            if not self.args.testModel and not self.args.generate and not self.args.writePerplexity:
                self.train(self.sess)
            elif self.args.writePerplexity:
                self.test(self.sess, tag='test')
            elif self.args.generate:
                self.generate(self.sess)
            else:
                self.test_model()
Example no. 3
class Predict:

    def __init__(self):
        self.args = None


        self.textData = None
        self.model = None
        self.outFile = None
        self.sess = None
        self.saver = None
        self.model_name = None
        self.model_path = None
        self.globalStep = 0
        self.summaryDir = None
        self.testOutFile = None
        self.summaryWriter = None
        self.mergedSummary = None

    @staticmethod
    def parse_args(args):

        parser = argparse.ArgumentParser()

        parser.add_argument('--resultDir', type=str, default='result', help='result directory')
        # data location
        dataArgs = parser.add_argument_group('Dataset options')

        dataArgs.add_argument('--random', type=int, default=3)
        dataArgs.add_argument('--backward', type=int, default=3)
        dataArgs.add_argument('--near', type=int, default=3)

        dataArgs.add_argument('--randomFile', type=str, default='random.csv')
        dataArgs.add_argument('--backwardFile', type=str, default='backward.csv')
        dataArgs.add_argument('--nearFile', type=str, default='near.csv')

        dataArgs.add_argument('--summaryDir', type=str, default='summaries')
        dataArgs.add_argument('--datasetName', type=str, default='dataset', help='a TextData object')

        dataArgs.add_argument('--dataDir', type=str, default='data', help='dataset directory, save pkl here')
        dataArgs.add_argument('--trainFile', type=str, default='train_dummy.csv')
        dataArgs.add_argument('--valFile', type=str, default='val.csv')
        dataArgs.add_argument('--testFile', type=str, default='test.csv')
        dataArgs.add_argument('--ethTest', type=str, default='eth_test.csv')

        dataArgs.add_argument('--embedFile', type=str, default='glove.840B.300d.txt')
        dataArgs.add_argument('--vocabSize', type=int, default=-1, help='vocab size, use the most frequent words')

        # neural network options
        nnArgs = parser.add_argument_group('Network options')
        nnArgs.add_argument('--embeddingSize', type=int, default=300)
        nnArgs.add_argument('--nSentences', type=int, default=6)
        nnArgs.add_argument('--hiddenSize', type=int, default=512, help='hiddenSize for RNN sentence encoder')
        nnArgs.add_argument('--attSize', type=int, default=512)
        nnArgs.add_argument('--sentenceAttSize', type=int, default=512)
        nnArgs.add_argument('--rnnLayers', type=int, default=1)
        nnArgs.add_argument('--maxSteps', type=int, default=30)
        nnArgs.add_argument('--numClasses', type=int, default=2)
        nnArgs.add_argument('--ffnnLayers', type=int, default=2)
        nnArgs.add_argument('--ffnnSize', type=int, default=300)
        nnArgs.add_argument('--pffnnLayers', type=int, default=2)
        nnArgs.add_argument('--pffnnSize', type=int, default=512)
        nnArgs.add_argument('--nn', type=str, default='att')
        # training options
        trainingArgs = parser.add_argument_group('Training options')
        trainingArgs.add_argument('--dataProcess', action='store_true')
        trainingArgs.add_argument('--modelPath', type=str, default='saved')
        trainingArgs.add_argument('--preEmbedding', action='store_true')
        trainingArgs.add_argument('--dropOut', type=float, default=0.8, help='dropout rate for RNN (keep prob)')
        trainingArgs.add_argument('--learningRate', type=float, default=0.001, help='learning rate')
        trainingArgs.add_argument('--batchSize', type=int, default=100, help='batch size')
        # max_grad_norm
        ## do not apply dropout in test mode!
        trainingArgs.add_argument('--twitterTest', action='store_true', help='whether or not to test on the Twitter dataset')
        trainingArgs.add_argument('--epochs', type=int, default=200, help='maximum number of training epochs')
        trainingArgs.add_argument('--device', type=str, default='/gpu:0', help='use the first GPU as default')
        trainingArgs.add_argument('--loadModel', action='store_true', help='whether or not to use old models')
        trainingArgs.add_argument('--testModel', action='store_true')
        trainingArgs.add_argument('--testTag', type=str, default='test')
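        # Example invocation (illustrative only; the script name and flag
        # values are placeholders, not taken from the original project):
        #   python main.py --nn att --batchSize 100 --epochs 200 --preEmbedding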
        return parser.parse_args(args)

    def main(self, args=None):
        print('TensorFlow version {}'.format(tf.VERSION))

        # initialize args
        self.args = self.parse_args(args)


        self.args.resultDir = self.args.nn +'_' + self.args.resultDir
        self.args.modelPath = self.args.nn +'_' + self.args.modelPath
        self.args.summaryDir = self.args.nn +'_' + self.args.summaryDir

        if not os.path.exists(self.args.resultDir):
            os.makedirs(self.args.resultDir)

        if not os.path.exists(self.args.modelPath):
            os.makedirs(self.args.modelPath)

        if not os.path.exists(self.args.summaryDir):
            os.makedirs(self.args.summaryDir)

        self.outFile = utils.constructFileName(self.args, prefix=self.args.resultDir)
        self.args.datasetName = utils.constructFileName(self.args, prefix=self.args.dataDir)
        datasetFileName = os.path.join(self.args.dataDir, self.args.datasetName)

        if not os.path.exists(datasetFileName):
            self.textData = TextData(self.args)
            with open(datasetFileName, 'wb') as datasetFile:
                p.dump(self.textData, datasetFile)
            print('dataset created and saved to {}'.format(datasetFileName))
        else:
            with open(datasetFileName, 'rb') as datasetFile:
                self.textData = p.load(datasetFile)
            print('dataset loaded from {}'.format(datasetFileName))

        if self.args.dataProcess:
            exit(0)

        sessConfig = tf.ConfigProto(allow_soft_placement=True)
        sessConfig.gpu_options.allow_growth = True

        self.model_path = utils.constructFileName(self.args, prefix=self.args.modelPath, tag='model')
        self.model_name = os.path.join(self.model_path, 'model')

        self.sess = tf.Session(config=sessConfig)
        # summary writer
        self.summaryDir = utils.constructFileName(self.args, prefix=self.args.summaryDir)

        with tf.device(self.args.device):
            if self.args.nn == 'vanilla':
                print('Creating vanilla model!')
                self.model = Model_vanilla(self.args, self.textData)
            elif self.args.nn == 'att':
                print('Creating model with sentences and words attention!')
                self.model = Model_att(self.args, self.textData)
            else:
                print('Creating model with only words attention!')
                self.model = Model_satt(self.args, self.textData)
            print('Model created')

            # saver can only be created after we have the model
            self.saver = tf.train.Saver()

            self.summaryWriter = tf.summary.FileWriter(self.summaryDir, self.sess.graph)
            self.mergedSummary = tf.summary.merge_all()

            if self.args.loadModel:
                # load model from disk
                if not os.path.exists(self.model_path):
                    print('model does not exist on disk!')
                    print(self.model_path)
                    exit(-1)

                self.saver.restore(sess=self.sess, save_path=self.model_name)
                print('Variables loaded from disk {}'.format(self.model_name))
            else:
                init = tf.global_variables_initializer()
                # initialize all global variables
                self.sess.run(init)
                print('All variables initialized')

            if not self.args.testModel:
                self.train(self.sess)
            else:
                self.testModel(sess=self.sess, tag=self.args.testTag)

    def train(self, sess):
        print('Start training')

        out = open(self.outFile, 'w', 1)
        out.write(self.outFile + '\n')
        utils.writeInfo(out, self.args)

        current_valAcc = 0.0

        for e in range(self.args.epochs):
            # training
            #trainBatches = self.textData.train_batches
            trainBatches = self.textData.get_batches(tag='train')
            totalTrainLoss = 0.0

            # cnt of batches
            cnt = 0

            total_samples = 0
            total_corrects = 0
            for nextBatch in tqdm(trainBatches):
                cnt += 1
                self.globalStep += 1

                total_samples += nextBatch.batch_size
                ops, feed_dict = self.model.step(nextBatch, test=False)

                _, loss, predictions, corrects = sess.run(ops, feed_dict)
                total_corrects += corrects
                totalTrainLoss += loss

                # log per-batch training loss to the summary writer
                self.summaryWriter.add_summary(utils.makeSummary({"trainLoss": loss}), self.globalStep)
            # compute accuracy over all samples in an epoch
            trainAcc = total_corrects*1.0/total_samples
            print('\nepoch = {}, Train, loss = {}, trainAcc = {}'.
                  format(e, totalTrainLoss, trainAcc))
            out.write('\nepoch = {}, loss = {}, trainAcc = {}\n'.
                  format(e, totalTrainLoss, trainAcc))
            out.flush()
            valAcc, valLoss = self.test(sess, tag='val')

            print('Val, loss = {}, valAcc = {}'.
                  format(valLoss, valAcc))
            out.write('Val, loss = {}, valAcc = {}\n'.
                  format(valLoss, valAcc))

            testAcc, testLoss = self.test(sess, tag='test')
            print('Test, loss = {}, testAcc = {}'.
                  format(testLoss, testAcc))
            out.write('Test, loss = {}, testAcc = {}\n'.
                  format(testLoss, testAcc))

            out.flush()

            # we do not use cross val currently, just train, then evaluate
            if valAcc >= current_valAcc:
                current_valAcc = valAcc
                print('New valAcc {} at epoch {}'.format(valAcc, e))
                out.write('New valAcc {} at epoch {}\n'.format(valAcc, e))
                save_path = self.saver.save(sess, save_path=self.model_name)
                print('model saved at {}'.format(save_path))
                out.write('model saved at {}\n'.format(save_path))

            out.flush()
        out.close()

    def createETH(self):
        samples = self.textData._create_samples(os.path.join(self.args.dataDir, self.args.ethTest))
        batches = self.textData._create_batch(samples)

        return batches

    def testModel(self, sess, tag='test'):
        if tag == 'test':
            print('Using original test set to test the performance')
            out_file_name = 'original_test_results.csv'
            batches = self.textData.test_batches
        else:
            print('Using ETH test set to test the performance')
            out_file_name = 'ETH_test_results.csv'
            batches = self.createETH()

        cnt = 0

        total_samples = 0
        total_corrects = 0
        total_loss = 0.0
        all_predictions = []
        for idx, nextBatch in enumerate(tqdm(batches)):
            cnt += 1

            total_samples += nextBatch.batch_size
            ops, feed_dict = self.model.step(nextBatch, test=True)

            loss, predictions, corrects = sess.run(ops, feed_dict)
            all_predictions.extend(predictions)
            total_loss += loss
            total_corrects += corrects


        with open(out_file_name, 'w') as file:
            for prediction in all_predictions:
                file.write(str(prediction) + '\n')
        acc = total_corrects*1.0/total_samples
        print(acc)
        print('Test Over!')


    def test(self, sess, tag='val'):
        if tag == 'val':
            print('Validating\n')
            batches = self.textData.val_batches
        else:
            print('Testing\n')
            batches = self.textData.test_batches

        cnt = 0

        total_samples = 0
        total_corrects = 0
        total_loss = 0.0
        all_predictions = []
        for idx, nextBatch in enumerate(tqdm(batches)):
            cnt += 1

            total_samples += nextBatch.batch_size
            ops, feed_dict = self.model.step(nextBatch, test=True)

            loss, predictions, corrects = sess.run(ops, feed_dict)
            all_predictions.extend(predictions)
            total_loss += loss
            total_corrects += corrects

        acc = total_corrects*1.0/total_samples
        return acc, total_loss
Example no. 4
	def main(self, args=None):
		print('TensorFlow version {}'.format(tf.VERSION))

		# initialize args
		self.args = self.parse_args(args)

		self.resultDir = os.path.join(self.args.resultDir, self.args.dataset)
		self.summaryDir = os.path.join(self.args.summaryDir, self.args.dataset)
		self.dataDir = os.path.join(self.args.dataDir, self.args.dataset)

		self.outFile = utils.constructFileName(self.args, prefix=self.resultDir)
		self.args.datasetName = utils.constructFileName(self.args, prefix=self.args.dataset, createDataSetName=True)
		datasetFileName = os.path.join(self.dataDir, self.args.datasetName)

		if not os.path.exists(self.resultDir):
			os.makedirs(self.resultDir)

		if not os.path.exists(self.args.modelPath):
			os.makedirs(self.args.modelPath)

		if not os.path.exists(self.summaryDir):
			os.makedirs(self.summaryDir)

		if not os.path.exists(datasetFileName):
			self.textData = TextData(self.args)
			with open(datasetFileName, 'wb') as datasetFile:
				p.dump(self.textData, datasetFile)
			print('dataset created and saved to {}'.format(datasetFileName))
		else:
			with open(datasetFileName, 'rb') as datasetFile:
				self.textData = p.load(datasetFile)
			print('dataset loaded from {}'.format(datasetFileName))


		self.modelPath = os.path.join(self.args.modelPath, self.args.dataset)
		self.model_path = utils.constructFileName(self.args, prefix=self.modelPath, tag='model')
		if not os.path.exists(self.model_path):
			os.makedirs(self.model_path)
		self.model_name = os.path.join(self.model_path, 'model')

		# summary writer
		self.summaryDir = utils.constructFileName(self.args, prefix=self.summaryDir)

		tf.enable_eager_execution()
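		# Under eager execution ops run imperatively, so model.step() /
		# buildNetwork() below are re-executed per batch and return
		# EagerTensors; self.to_numpy() (a project-specific helper) presumably
		# converts them to NumPy arrays.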
		self.model = Model(self.args, self.textData)

		for e in range(self.args.epochs):
			# training
			trainBatches = self.textData.train_batches
			#trainBatches = self.textData.get_batches(tag='train')
			totalTrainLoss = 0.0

			# cnt of batches
			cnt = 0

			total_samples = 0
			total_corrects = 0
			all_skip_rate = []
			for idx, nextBatch in enumerate(tqdm(trainBatches)):
				cnt += 1
				#nextBatch = trainBatches[227]
				self.globalStep += 1

				total_samples += nextBatch.batch_size

				self.model.step(nextBatch, test=False)

				_, loss, predictions, corrects, skip_rate, skip_flag = self.model.buildNetwork()

				loss = self.to_numpy(loss)
				predictions = self.to_numpy(predictions)
				corrects = self.to_numpy(corrects)
				skip_rate = self.to_numpy(skip_rate)
				skip_flag = self.to_numpy(skip_flag)

				# skip_rate: batch_size * n_samples
				all_skip_rate.extend(skip_rate.tolist())
				#print(loss, idx)
				total_corrects += corrects
				totalTrainLoss += loss


			trainAcc = total_corrects * 1.0 / (total_samples*self.args.nSamples)
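			# (accuracy above is normalized by total_samples * nSamples, assuming
			#  each batch element carries nSamples scored candidates; --nSamples
			#  comes from a parse_args not shown in this excerpt)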
			train_skip_rate = np.average(all_skip_rate)
			print('\nepoch = {}, Train, loss = {}, trainAcc = {}, train_skip_rate = {}'.
			      format(e, totalTrainLoss, trainAcc, train_skip_rate))
Example no. 5
    def main(self, args=None):
        print('TensorFlow version {}'.format(tf.VERSION))

        # initialize args
        self.args = self.parse_args(args)

        self.resultDir = os.path.join(self.args.resultDir, self.args.dataset)
        self.summaryDir = os.path.join(self.args.summaryDir, self.args.dataset)
        self.dataDir = os.path.join(self.args.dataDir, self.args.dataset)
        self.testDir = os.path.join(self.args.testDir, self.args.dataset)

        self.outFile = utils.constructFileName(self.args,
                                               prefix=self.resultDir)
        self.testFile = utils.constructFileName(self.args, prefix=self.testDir)

        self.args.datasetName = utils.constructFileName(
            self.args, prefix=self.args.dataset, createDataSetName=True)
        datasetFileName = os.path.join(self.dataDir, self.args.datasetName)

        if not os.path.exists(self.resultDir):
            os.makedirs(self.resultDir)

        if not os.path.exists(self.testDir):
            os.makedirs(self.testDir)

        if not os.path.exists(self.args.modelPath):
            os.makedirs(self.args.modelPath)

        if not os.path.exists(self.summaryDir):
            os.makedirs(self.summaryDir)

        if not os.path.exists(datasetFileName):
            self.textData = TextData(self.args)
            with open(datasetFileName, 'wb') as datasetFile:
                p.dump(self.textData, datasetFile)
            print('dataset created and saved to {}, exiting ...'.format(
                datasetFileName))
            exit(0)
        else:
            with open(datasetFileName, 'rb') as datasetFile:
                self.textData = p.load(datasetFile)
            print('dataset loaded from {}'.format(datasetFileName))

        # self.statistics()

        sessConfig = tf.ConfigProto(allow_soft_placement=True)
        sessConfig.gpu_options.allow_growth = True

        self.model_path = os.path.join(self.args.modelPath, self.args.dataset)
        self.model_path = utils.constructFileName(self.args,
                                                  prefix=self.model_path,
                                                  tag='model')
        if not os.path.exists(self.model_path):
            os.makedirs(self.model_path)
        self.model_name = os.path.join(self.model_path, 'model')

        self.sess = tf.Session(config=sessConfig)
        # summary writer
        self.summaryDir = utils.constructFileName(self.args,
                                                  prefix=self.summaryDir)
        if self.args.eager:
            tf.enable_eager_execution(
                config=sessConfig,
                device_policy=tf.contrib.eager.DEVICE_PLACEMENT_WARN)
            print('eager execution enabled')

        # import timeit
        #
        # start = timeit.default_timer()
        #
        # self.textData.get_batches(tag='train', augment=self.args.augment)
        #
        # stop = timeit.default_timer()
        #
        # print('Time: ', stop - start)
        # exit(0)

        with tf.device(self.args.device):
            if self.args.model.find('hbmp') != -1:
                if self.args.model.find('share') != -1:
                    print('Creating model with HBMP share')
                    self.model = ModelHBMPShare(self.args,
                                                self.textData,
                                                eager=self.args.eager)
                else:
                    print('Creating model with HBMP ordinary')
                    self.model = ModelHBMP(self.args,
                                           self.textData,
                                           eager=self.args.eager)
            else:
                self.model = ModelBasic(self.args,
                                        self.textData,
                                        eager=self.args.eager)
                print('Basic model created!')

            if self.args.eager:
                self.train_eager()
                exit(0)
            # saver can only be created after we have the model
            self.saver = tf.train.Saver()

            self.summaryWriter = tf.summary.FileWriter(self.summaryDir,
                                                       self.sess.graph)
            self.mergedSummary = tf.summary.merge_all()

            if self.args.loadModel:
                # load model from disk
                if not os.path.exists(self.model_path):
                    print('model does not exist on disk!')
                    print(self.model_path)
                    exit(-1)

                self.saver.restore(sess=self.sess, save_path=self.model_name)
                print('Variables loaded from disk {}'.format(self.model_name))
            else:
                init = tf.global_variables_initializer()
                table_init = tf.tables_initializer()
                # initialize all global variables
                self.sess.run([init, table_init])
                print('All variables initialized')

            self.train(self.sess)
Example no. 6
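# Assumed imports for this example (a minimal sketch): the standard-library and
# TensorFlow 1.x imports below match the calls in the code; the project-local
# modules (utils, TextData, ModelBasic, ModelHBMP, ModelHBMPShare) are
# hypothetical placeholders.
import argparse
import os
import pickle as p

import tensorflow as tf
from tqdm import tqdm

# hypothetical project-local modules:
# import utils
# from textdata import TextData
# from model import ModelBasic, ModelHBMP, ModelHBMPShare
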
class Train:
    def __init__(self):
        self.args = None

        self.textData = None
        self.model = None
        self.outFile = None
        self.sess = None
        self.saver = None
        self.model_name = None
        self.model_path = None
        self.globalStep = 0
        self.summaryDir = None
        self.testOutFile = None
        self.summaryWriter = None
        self.mergedSummary = None

    @staticmethod
    def parse_args(args):
        parser = argparse.ArgumentParser()

        parser.add_argument('--resultDir',
                            type=str,
                            default='result',
                            help='result directory')
        parser.add_argument('--testDir', type=str, default='test_result')
        # data location
        dataArgs = parser.add_argument_group('Dataset options')

        dataArgs.add_argument('--summaryDir', type=str, default='summaries')
        dataArgs.add_argument('--datasetName',
                              type=str,
                              default='dataset',
                              help='a TextData object')

        dataArgs.add_argument('--dataDir',
                              type=str,
                              default='data',
                              help='dataset directory, save pkl here')
        dataArgs.add_argument('--dataset', type=str, default='emo')
        dataArgs.add_argument('--trainFile', type=str, default='train.txt')
        dataArgs.add_argument('--valFile', type=str, default='dev.txt')
        dataArgs.add_argument('--testFile', type=str, default='test.txt')

        dataArgs.add_argument('--trainLiwcFile',
                              type=str,
                              default='train_liwc.csv')
        dataArgs.add_argument('--valLiwcFile',
                              type=str,
                              default='dev_liwc.csv')
        dataArgs.add_argument('--testLiwcFile',
                              type=str,
                              default='dev_liwc.csv')

        dataArgs.add_argument('--embeddingFile',
                              type=str,
                              default='glove.840B.300d.txt')
        dataArgs.add_argument('--vocabSize',
                              type=int,
                              default=-1,
                              help='vocab size, use the most frequent words')

        dataArgs.add_argument('--snliDir', type=str, default='snli')
        dataArgs.add_argument('--trainSnliFile',
                              type=str,
                              default='train_snli.txt')
        dataArgs.add_argument('--valSnliFile',
                              type=str,
                              default='dev_snli.txt')
        dataArgs.add_argument('--testSnliFile',
                              type=str,
                              default='test_snli.txt')

        # neural network options
        nnArgs = parser.add_argument_group('Network options')
        nnArgs.add_argument('--embeddingSize', type=int, default=300)
        nnArgs.add_argument('--hiddenSize',
                            type=int,
                            default=300,
                            help='hiddenSize for RNN sentence encoder')
        nnArgs.add_argument(
            '--rnnLayers',
            type=int,
            default=1,
            help='number of RNN layers, fix to 1 in the DCRNN model')
        nnArgs.add_argument('--maxSteps', type=int, default=30)
        nnArgs.add_argument('--emoClasses', type=int, default=4)
        nnArgs.add_argument('--snliClasses', type=int, default=2)
        nnArgs.add_argument('--nTurn', type=int, default=3)
        nnArgs.add_argument('--speakerEmbedSize', type=int, default=0)
        nnArgs.add_argument('--nLSTM',
                            type=int,
                            default=3,
                            help='in DCRNN, this is the ')
        nnArgs.add_argument('--heads', type=int, default=3)
        nnArgs.add_argument('--attn', action='store_true')
        nnArgs.add_argument('--selfattn', action='store_true')
        nnArgs.add_argument('--nContexts', type=int, default=4)
        nnArgs.add_argument('--independent', action='store_true')

        # training options
        trainingArgs = parser.add_argument_group('Training options')
        trainingArgs.add_argument(
            '--model',
            type=str,
            help='hbmp, dcrnn, transfer+hbmp, transfer+dcrnn, hbmp+share')
        trainingArgs.add_argument('--eager',
                                  action='store_true',
                                  help='turn on eager mode for debugging')
        trainingArgs.add_argument('--modelPath', type=str, default='saved')
        trainingArgs.add_argument('--preEmbedding', action='store_true')
        trainingArgs.add_argument('--elmo', action='store_true')
        trainingArgs.add_argument('--trainElmo', action='store_true')
        trainingArgs.add_argument('--dropOut',
                                  type=float,
                                  default=1.0,
                                  help='dropout rate for RNN (keep prob)')
        trainingArgs.add_argument('--learningRate',
                                  type=float,
                                  default=0.001,
                                  help='learning rate')
        trainingArgs.add_argument('--batchSize',
                                  type=int,
                                  default=100,
                                  help='batch size')
        trainingArgs.add_argument('--epochs',
                                  type=int,
                                  default=200,
                                  help='maximum number of training epochs')
        trainingArgs.add_argument('--device',
                                  type=str,
                                  default='/gpu:0',
                                  help='use the first GPU as default')
        trainingArgs.add_argument('--loadModel',
                                  action='store_true',
                                  help='whether or not to use old models')
        trainingArgs.add_argument(
            '--sampleWeight',
            default=6.848,
            type=float,
            help='a constant to balance different categories')
        trainingArgs.add_argument(
            '--weighted',
            action='store_true',
            help='whether or not to weight the training samples')
        trainingArgs.add_argument('--ffnn',
                                  type=int,
                                  default=500,
                                  help='intermediate ffnn size')
        trainingArgs.add_argument(
            '--gamma',
            type=float,
            default=0.1,
            help='we use a lambda to balance between snli and emo,'
            'multiply gradients snli samples by this lambda')
        trainingArgs.add_argument(
            '--universal',
            action='store_true',
            help='whether or not to use universal sent embeddings')
        trainingArgs.add_argument(
            '--augment',
            action='store_true',
            help='data augmentation by adding random 2nd sentences')
        """
		in training data: happy, sad, angry: 5k (16.67%) each; others: 15k (50%)
		in dev/test: happy, sad, angry: 4% each; others: 88%
					88/3 = 29.333
					29.33/4 = 7.33

in genuine data:					
					
30160 train, happy = 4243, of 0.14068302387267906, sad = 5463, of 0.1811339522546419, angry = 5506, of 0.18255968169761272, others = 14948, of 0.4956233421750663
2755 val, happy = 142, of 0.051542649727767696, sad = 125, of 0.045372050816696916, angry = 150, of 0.0544464609800363, others = 2338, of 0.8486388384754991
		
		"""
        return parser.parse_args(args)

    def statistics(self):
        """
		27144 train, happy = 3815, of 0.14054671382257589, sad = 4920, of 0.1812555260831123,
		            angry = 4977, of 0.1833554376657825, others = 13432, of 0.4948423224285293

		3016 val, happy = 428, of 0.1419098143236074, sad = 543, of 0.18003978779840848,
					angry = 529, of 0.17539787798408488, others = 1516, of 0.5026525198938993
		:return:
		"""
        train_samples = self.textData.train_samples
        val_samples = self.textData.valid_samples

        def cnt(samples, tag='train'):
            happy = 0
            sad = 0
            angry = 0
            others = 0
            for sample in samples:
                if sample.label == self.textData.label2idx['happy']:
                    happy += 1
                elif sample.label == self.textData.label2idx['sad']:
                    sad += 1
                elif sample.label == self.textData.label2idx['angry']:
                    angry += 1
                else:
                    others += 1
            total = happy + sad + angry + others
            print(
                '{} {}, happy = {}, of {}, sad = {}, of {}, angry = {}, of {}, others = {}, of {}'
                .format(total, tag, happy, happy / total, sad, sad / total,
                        angry, angry / total, others, others / total))

        cnt(train_samples, 'train')
        cnt(val_samples, 'val')
        exit(0)

    def main(self, args=None):
        print('TensorFlow version {}'.format(tf.VERSION))

        # initialize args
        self.args = self.parse_args(args)

        self.resultDir = os.path.join(self.args.resultDir, self.args.dataset)
        self.summaryDir = os.path.join(self.args.summaryDir, self.args.dataset)
        self.dataDir = os.path.join(self.args.dataDir, self.args.dataset)
        self.testDir = os.path.join(self.args.testDir, self.args.dataset)

        self.outFile = utils.constructFileName(self.args,
                                               prefix=self.resultDir)
        self.testFile = utils.constructFileName(self.args, prefix=self.testDir)

        self.args.datasetName = utils.constructFileName(
            self.args, prefix=self.args.dataset, createDataSetName=True)
        datasetFileName = os.path.join(self.dataDir, self.args.datasetName)

        if not os.path.exists(self.resultDir):
            os.makedirs(self.resultDir)

        if not os.path.exists(self.testDir):
            os.makedirs(self.testDir)

        if not os.path.exists(self.args.modelPath):
            os.makedirs(self.args.modelPath)

        if not os.path.exists(self.summaryDir):
            os.makedirs(self.summaryDir)

        if not os.path.exists(datasetFileName):
            self.textData = TextData(self.args)
            with open(datasetFileName, 'wb') as datasetFile:
                p.dump(self.textData, datasetFile)
            print('dataset created and saved to {}, exiting ...'.format(
                datasetFileName))
            exit(0)
        else:
            with open(datasetFileName, 'rb') as datasetFile:
                self.textData = p.load(datasetFile)
            print('dataset loaded from {}'.format(datasetFileName))

        # self.statistics()

        sessConfig = tf.ConfigProto(allow_soft_placement=True)
        sessConfig.gpu_options.allow_growth = True

        self.model_path = os.path.join(self.args.modelPath, self.args.dataset)
        self.model_path = utils.constructFileName(self.args,
                                                  prefix=self.model_path,
                                                  tag='model')
        if not os.path.exists(self.model_path):
            os.makedirs(self.model_path)
        self.model_name = os.path.join(self.model_path, 'model')

        self.sess = tf.Session(config=sessConfig)
        # summary writer
        self.summaryDir = utils.constructFileName(self.args,
                                                  prefix=self.summaryDir)
        if self.args.eager:
            tf.enable_eager_execution(
                config=sessConfig,
                device_policy=tf.contrib.eager.DEVICE_PLACEMENT_WARN)
            print('eager execution enabled')

        # import timeit
        #
        # start = timeit.default_timer()
        #
        # self.textData.get_batches(tag='train', augment=self.args.augment)
        #
        # stop = timeit.default_timer()
        #
        # print('Time: ', stop - start)
        # exit(0)

        with tf.device(self.args.device):
            if self.args.model.find('hbmp') != -1:
                if self.args.model.find('share') != -1:
                    print('Creating model with HBMP share')
                    self.model = ModelHBMPShare(self.args,
                                                self.textData,
                                                eager=self.args.eager)
                else:
                    print('Creating model with HBMP ordinary')
                    self.model = ModelHBMP(self.args,
                                           self.textData,
                                           eager=self.args.eager)
            else:
                self.model = ModelBasic(self.args,
                                        self.textData,
                                        eager=self.args.eager)
                print('Basic model created!')

            if self.args.eager:
                self.train_eager()
                exit(0)
            # saver can only be created after we have the model
            self.saver = tf.train.Saver()

            self.summaryWriter = tf.summary.FileWriter(self.summaryDir,
                                                       self.sess.graph)
            self.mergedSummary = tf.summary.merge_all()

            if self.args.loadModel:
                # load model from disk
                if not os.path.exists(self.model_path):
                    print('model does not exist on disk!')
                    print(self.model_path)
                    exit(-1)

                self.saver.restore(sess=self.sess, save_path=self.model_name)
                print('Variables loaded from disk {}'.format(self.model_name))
            else:
                init = tf.global_variables_initializer()
                table_init = tf.tables_initializer()
                # initialize all global variables
                self.sess.run([init, table_init])
                print('All variables initialized')

            self.train(self.sess)

    def train_eager(self):
        for e in range(self.args.epochs):
            trainBatches = self.textData.train_batches

            for idx, nextBatch in enumerate(tqdm(trainBatches)):
                self.model.step(nextBatch, test=False, eager=self.args.eager)
                self.model.buildNetwork()

                print()

    def train(self, sess):
        #sess = tf_debug.LocalCLIDebugWrapperSession(sess)

        print('Start training')

        out = open(self.outFile, 'w', 1)
        out.write(self.outFile + '\n')
        utils.writeInfo(out, self.args)

        current_val_f1_micro_unweighted = 0.0

        for e in range(self.args.epochs):
            # training
            trainBatches = self.textData.get_batches(tag='train',
                                                     augment=self.args.augment)
            #trainBatches = self.textData.train_batches
            totalTrainLoss = 0.0

            # cnt of batches
            cnt = 0

            total_samples = 0
            total_corrects = 0

            all_predictions = []
            all_labels = []
            all_sample_weights = []

            for idx, nextBatch in enumerate(tqdm(trainBatches)):

                cnt += 1
                self.globalStep += 1
                total_samples += nextBatch.batch_size
                # print(idx)

                ops, feed_dict, labels, sample_weights = self.model.step(
                    nextBatch, test=False)
                _, loss, predictions, corrects = sess.run(ops, feed_dict)
                all_predictions.extend(predictions)
                all_labels.extend(labels)
                all_sample_weights.extend(sample_weights)
                total_corrects += corrects
                totalTrainLoss += loss

                self.summaryWriter.add_summary(
                    utils.makeSummary({"train_loss": loss}), self.globalStep)
                #break
            trainAcc = total_corrects * 1.0 / total_samples

            # calculate f1 score for train (weighted/unweighted)
            train_f1_micro, train_f1_macro, train_p_micro, train_r_micro, train_p_macro, train_r_macro\
             = self.cal_F1(y_pred=all_predictions, y_true=all_labels)
            train_f1_micro_w, train_f1_macro_w, train_p_micro_w, train_r_micro_w, train_p_macro_w, train_r_macro_w\
             = self.cal_F1(y_pred=all_predictions, y_true=all_labels,
                                                             sample_weight=all_sample_weights)
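            # (cal_F1 is not shown in this excerpt; presumably it wraps
            #  something like sklearn.metrics.precision_recall_fscore_support
            #  with average='micro' / 'macro', optionally using sample_weight)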

            print(
                '\nepoch = {}, Train, loss = {}, trainAcc = {}, train_f1_micro = {}, train_f1_macro = {},'
                ' train_f1_micro_w = {}, train_f1_macro_w = {}'.format(
                    e, totalTrainLoss, trainAcc, train_f1_micro,
                    train_f1_macro, train_f1_micro_w, train_f1_macro_w))
            print(
                '\ttrain_p_micro = {}, train_r_micro = {}, train_p_macro = {}, train_r_macro = {}'
                .format(train_p_micro, train_r_micro, train_p_macro,
                        train_r_macro))
            print(
                '\ttrain_p_micro_w = {}, train_r_micro_w = {}, train_p_macro_w = {}, train_r_macro_w = {}'
                .format(train_p_micro_w, train_r_micro_w, train_p_macro_w,
                        train_r_macro_w))

            out.write(
                '\nepoch = {}, loss = {}, trainAcc = {}, train_f1_micro = {}, train_f1_macro = {},'
                ' train_f1_micro_w = {}, train_f1_macro_w = {}\n'.format(
                    e, totalTrainLoss, trainAcc, train_f1_micro,
                    train_f1_macro, train_f1_micro_w, train_f1_macro_w))
            out.write(
                '\ttrain_p_micro = {}, train_r_micro = {}, train_p_macro = {}, train_r_macro = {}\n'
                .format(train_p_micro, train_r_micro, train_p_macro,
                        train_r_macro))
            out.write(
                '\ttrain_p_micro_w = {}, train_r_micro_w = {}, train_p_macro_w = {}, train_r_macro_w = {}\n'
                .format(train_p_micro_w, train_r_micro_w, train_p_macro_w,
                        train_r_macro_w))
            #continue

            out.flush()

            # calculate f1 score for val (weighted/unweighted)
            valAcc, valLoss, val_f1_micro, val_f1_macro, val_f1_micro_w, val_f1_macro_w,\
            val_p_micro, val_r_micro, val_p_macro, val_r_macro, val_p_micro_w, val_r_micro_w, val_p_macro_w, val_r_macro_w\
             = self.test(sess, tag='val')

            print(
                '\n\tVal, loss = {}, valAcc = {}, val_f1_micro = {}, val_f1_macro = {}, val_f1_micro_w = {}, val_f1_macro_w = {}'
                .format(valLoss, valAcc, val_f1_micro, val_f1_macro,
                        val_f1_micro_w, val_f1_macro_w))
            print(
                '\t\t val_p_micro = {}, val_r_micro = {}, val_p_macro = {}, val_r_macro = {}'
                .format(val_p_micro, val_r_micro, val_p_macro, val_r_macro))
            print(
                '\t\t val_p_micro_w = {}, val_r_micro_w = {}, val_p_macro_w = {}, val_r_macro_w = {}'
                .format(val_p_micro_w, val_r_micro_w, val_p_macro_w,
                        val_r_macro_w))

            out.write(
                '\n\tVal, loss = {}, valAcc = {}, val_f1_micro = {}, val_f1_macro = {}, val_f1_micro_w = {}, val_f1_macro_w = {}\n'
                .format(valLoss, valAcc, val_f1_micro, val_f1_macro,
                        val_f1_micro_w, val_f1_macro_w))
            out.write(
                '\t\t val_p_micro = {}, val_r_micro = {}, val_p_macro = {}, val_r_macro = {}\n'
                .format(val_p_micro, val_r_micro, val_p_macro, val_r_macro))
            out.write(
                '\t\t val_p_micro_w = {}, val_r_micro_w = {}, val_p_macro_w = {}, val_r_macro_w = {}\n'
                .format(val_p_micro_w, val_r_micro_w, val_p_macro_w,
                        val_r_macro_w))

            # calculate f1 score for test (weighted/unweighted)
            testAcc, testLoss, test_f1_micro, test_f1_macro, test_f1_micro_w, test_f1_macro_w,\
            test_p_micro, test_r_micro, test_p_macro, test_r_macro, test_p_micro_w, test_r_micro_w, test_p_macro_w, test_r_macro_w\
             = self.test(sess, tag='test')

            print(
                '\n\ttest, loss = {}, testAcc = {}, test_f1_micro = {}, test_f1_macro = {}, test_f1_micro_w = {}, test_f1_macro_w = {}'
                .format(testLoss, testAcc, test_f1_micro, test_f1_macro,
                        test_f1_micro_w, test_f1_macro_w))
            print(
                '\t\t test_p_micro = {}, test_r_micro = {}, test_p_macro = {}, test_r_macro = {}'
                .format(test_p_micro, test_r_micro, test_p_macro,
                        test_r_macro))
            print(
                '\t\t test_p_micro_w = {}, test_r_micro_w = {}, test_p_macro_w = {}, test_r_macro_w = {}'
                .format(test_p_micro_w, test_r_micro_w, test_p_macro_w,
                        test_r_macro_w))

            out.write(
                '\n\ttest, loss = {}, testAcc = {}, test_f1_micro = {}, test_f1_macro = {}, test_f1_micro_w = {}, test_f1_macro_w = {}\n'
                .format(testLoss, testAcc, test_f1_micro, test_f1_macro,
                        test_f1_micro_w, test_f1_macro_w))
            out.write(
                '\t\t test_p_micro = {}, test_r_micro = {}, test_p_macro = {}, test_r_macro = {}\n'
                .format(test_p_micro, test_r_micro, test_p_macro,
                        test_r_macro))
            out.write(
                '\t\t test_p_micro_w = {}, test_r_micro_w = {}, test_p_macro_w = {}, test_r_macro_w = {}\n'
                .format(test_p_micro_w, test_r_micro_w, test_p_macro_w,
                        test_r_macro_w))

            out.flush()

            self.summaryWriter.add_summary(
                utils.makeSummary({"train_acc": trainAcc}), e)
            self.summaryWriter.add_summary(
                utils.makeSummary({"val_acc": valAcc}), e)

            # use val_f1_micro (unweighted) as metric
            if val_f1_micro >= current_val_f1_micro_unweighted:
                current_val_f1_micro_unweighted = val_f1_micro
                print('New val f1_micro {} at epoch {}'.format(
                    val_f1_micro, e))
                out.write('New val f1_micro {} at epoch {}\n'.format(
                    val_f1_micro, e))

                save_path = self.saver.save(sess, save_path=self.model_name)
                print('model saved at {}'.format(save_path))
                out.write('model saved at {}\n'.format(save_path))

                test_predictions = self.test(sess, tag='test2')
                print('Writing predictions at epoch {}'.format(e))
                out.write('Writing predictions at epoch {}\n'.format(e))
                test_file = self.write_predictions(test_predictions,
                                                   tag='unweighted')

                print('Writing predictions to {}'.format(test_file))
                out.write('Writing predictions to {}\n'.format(test_file))

            out.flush()
        out.close()

    def write_predictions(self, predictions, tag='weighted'):
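        # One tab-separated line per test sample: id, the three turns (each
        # truncated to its true length and stripped to ASCII), then the
        # predicted label name mapped back through idx2label.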
        test_file = self.testFile + '_' + tag
        with open(test_file, 'w') as file:
            file.write('id\tturn1\tturn2\tturn3\tlabel\n')
            idx2label = {v: k for k, v in self.textData.label2idx.items()}
            for idx, sample in enumerate(self.textData.test_samples):
                assert idx == sample.id

                file.write(str(idx) + '\t')
                for ind, sent in enumerate(sample.sents):
                    file.write(' '.join(sent[:sample.length[ind]]).encode(
                        'ascii', 'ignore').decode('ascii') + '\t')
                file.write(idx2label[predictions[idx]] + '\n')
        return test_file

    def test(self, sess, tag='val'):
        """
		for the real dev data, during test, do not use sample weights
		:param sess:
		:param tag:
		:return:
		"""
        if tag == 'val':
            print('Validating\n')
            batches = self.textData.val_batches
        else:
            print('Testing\n')
            batches = self.textData.test_batches

        cnt = 0

        total_samples = 0
        total_corrects = 0
        total_loss = 0.0
        all_predictions = []
        all_labels = []
        all_sample_weights = []
        for idx, nextBatch in enumerate(tqdm(batches)):
            cnt += 1

            total_samples += nextBatch.batch_size
            ops, feed_dict, labels, sample_weights = self.model.step(nextBatch,
                                                                     test=True)

            loss, predictions, corrects = sess.run(ops, feed_dict)
            all_predictions.extend(predictions)
            all_labels.extend(labels)
            all_sample_weights.extend(sample_weights)
            total_loss += loss
            total_corrects += corrects

            #break

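        # Unweighted micro/macro metrics over all predictions, then the same
        # metrics recomputed with the per-sample weights.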
        f1_micro, f1_macro, p_micro, r_micro, p_macro, r_macro = self.cal_F1(
            y_pred=all_predictions, y_true=all_labels)
        f1_micro_w, f1_macro_w, p_micro_w, r_micro_w, p_macro_w, r_macro_w =\
         self.cal_F1(y_pred=all_predictions, y_true=all_labels, sample_weight=all_sample_weights)

        acc = total_corrects * 1.0 / total_samples

        if tag == 'test2':
            return all_predictions
        else:
            return acc, total_loss, f1_micro, f1_macro, f1_micro_w, f1_macro_w,\
                   p_micro, r_micro, p_macro, r_macro,\
                   p_micro_w, r_micro_w, p_macro_w, r_macro_w

    def cal_F1(self, y_pred, y_true, sample_weight=None):
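        # Micro/macro precision, recall and F1 are computed only over the
        # happy/sad/angry label indices; any other class is excluded from the
        # averages because labels= is passed to precision_recall_fscore_support.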
        labels = [
            self.textData.label2idx['happy'], self.textData.label2idx['sad'],
            self.textData.label2idx['angry']
        ]
        # if sample_weight is not None:
        # 	sample_weight = np.asarray(sample_weight)
        # 	sample_weight = sample_weight.astype(int)
        p_micro, r_micro, f1_micro, _ = \
         precision_recall_fscore_support(y_true=y_true, y_pred=y_pred, average='micro', labels=labels, sample_weight=sample_weight)
        p_macro, r_macro, f1_macro, _ = \
         precision_recall_fscore_support(y_true=y_true, y_pred=y_pred, average='macro', labels=labels, sample_weight=sample_weight)

        return f1_micro, f1_macro, p_micro, r_micro, p_macro, r_macro
class Train:
    def __init__(self):
        self.args = None

        self.textData = None
        self.model = None
        self.outFile = None
        self.sess = None
        self.saver = None
        self.model_name = None
        self.model_path = None
        self.globalStep = 0
        self.summaryDir = None
        self.testOutFile = None
        self.summaryWriter = None
        self.mergedSummary = None

    @staticmethod
    def parse_args(args):
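        # Example invocation (the entry-point script name is assumed here):
        #   python train.py --dataset rotten --skim --preEmbedding --epochs 40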
        parser = argparse.ArgumentParser()

        parser.add_argument('--resultDir',
                            type=str,
                            default='result',
                            help='result directory')
        parser.add_argument('--testDir', type=str, default='test_result')
        # data location
        dataArgs = parser.add_argument_group('Dataset options')

        dataArgs.add_argument('--summaryDir', type=str, default='summaries')
        dataArgs.add_argument('--datasetName',
                              type=str,
                              default='dataset',
                              help='a TextData object')

        dataArgs.add_argument('--dataDir',
                              type=str,
                              default='data',
                              help='dataset directory, save pkl here')
        dataArgs.add_argument('--dataset', type=str, default='rotten')
        dataArgs.add_argument('--trainFile', type=str, default='train.txt')
        dataArgs.add_argument('--valFile', type=str, default='val.txt')
        dataArgs.add_argument('--testFile', type=str, default='test.txt')
        dataArgs.add_argument('--embeddingFile',
                              type=str,
                              default='glove.840B.300d.txt')
        dataArgs.add_argument('--vocabSize',
                              type=int,
                              default=-1,
                              help='vocab size, use the most frequent words')

        # neural network options
        nnArgs = parser.add_argument_group('Network options')
        nnArgs.add_argument('--embeddingSize', type=int, default=300)
        nnArgs.add_argument('--hiddenSize',
                            type=int,
                            default=200,
                            help='hiddenSize for RNN sentence encoder')
        nnArgs.add_argument('--rnnLayers', type=int, default=1)
        nnArgs.add_argument('--maxSteps', type=int, default=50)
        nnArgs.add_argument('--numClasses', type=int, default=2)
        nnArgs.add_argument('--skim', action='store_true')
        # training options
        trainingArgs = parser.add_argument_group('Training options')
        trainingArgs.add_argument('--modelPath', type=str, default='saved')
        trainingArgs.add_argument('--preEmbedding', action='store_true')
        trainingArgs.add_argument('--dropOut',
                                  type=float,
                                  default=1.0,
                                  help='dropout rate for RNN (keep prob)')
        trainingArgs.add_argument('--learningRate',
                                  type=float,
                                  default=0.001,
                                  help='learning rate')
        trainingArgs.add_argument('--batchSize',
                                  type=int,
                                  default=32,
                                  help='batch size')
        trainingArgs.add_argument('--skimloss',
                                  action='store_true',
                                  help='whether or not to encourage skimming')
        trainingArgs.add_argument('--minRead', type=int, default=2)
        trainingArgs.add_argument('--maxSkip', type=int, default=5)
        trainingArgs.add_argument('--discount', type=float, default=0.99)
        # max_grad_norm
        ## do not apply dropout in test mode!
        trainingArgs.add_argument('--epochs',
                                  type=int,
                                  default=40,
                                  help='maximum number of training epochs')
        trainingArgs.add_argument('--device',
                                  type=str,
                                  default='/gpu:0',
                                  help='use the first GPU as default')
        trainingArgs.add_argument('--loadModel',
                                  action='store_true',
                                  help='whether or not to use old models')
        trainingArgs.add_argument('--testModel', action='store_true')
        trainingArgs.add_argument('--printgate', action='store_true')
        trainingArgs.add_argument('--nSamples', type=int, default=3)
        trainingArgs.add_argument('--eps', type=float, default=0.1)
        return parser.parse_args(args)

    def main(self, args=None):
        print('TensorFlow version {}'.format(tf.VERSION))

        # initialize args
        self.args = self.parse_args(args)

        self.resultDir = os.path.join(self.args.resultDir, self.args.dataset)
        self.summaryDir = os.path.join(self.args.summaryDir, self.args.dataset)
        self.dataDir = os.path.join(self.args.dataDir, self.args.dataset)

        self.outFile = utils.constructFileName(self.args,
                                               prefix=self.resultDir)
        self.args.datasetName = utils.constructFileName(
            self.args, prefix=self.args.dataset, createDataSetName=True)
        datasetFileName = os.path.join(self.dataDir, self.args.datasetName)

        if not os.path.exists(self.resultDir):
            os.makedirs(self.resultDir)

        if not os.path.exists(self.args.modelPath):
            os.makedirs(self.args.modelPath)

        if not os.path.exists(self.summaryDir):
            os.makedirs(self.summaryDir)

        if not os.path.exists(datasetFileName):
            self.textData = TextData(self.args)
            with open(datasetFileName, 'wb') as datasetFile:
                p.dump(self.textData, datasetFile)
            print('dataset created and saved to {}'.format(datasetFileName))
        else:
            with open(datasetFileName, 'rb') as datasetFile:
                self.textData = p.load(datasetFile)
            print('dataset loaded from {}'.format(datasetFileName))

        sessConfig = tf.ConfigProto(allow_soft_placement=True)
        sessConfig.gpu_options.allow_growth = True

        self.modelPath = os.path.join(self.args.modelPath, self.args.dataset)
        self.model_path = utils.constructFileName(self.args,
                                                  prefix=self.modelPath,
                                                  tag='model')
        if not os.path.exists(self.model_path):
            os.makedirs(self.model_path)
        self.model_name = os.path.join(self.model_path, 'model')
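        # 'model' becomes the checkpoint prefix used by saver.save()/restore() below.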

        self.sess = tf.Session(config=sessConfig)
        # summary writer
        self.summaryDir = utils.constructFileName(self.args,
                                                  prefix=self.summaryDir)

        with tf.device(self.args.device):
            if self.args.skim:
                print('Skim model created')
                self.model = Model(self.args, self.textData)
            else:
                print('Ordinary model created!')
                self.model = ModelBasic(self.args, self.textData)

            # saver can only be created after we have the model
            self.saver = tf.train.Saver()

            self.summaryWriter = tf.summary.FileWriter(self.summaryDir,
                                                       self.sess.graph)
            self.mergedSummary = tf.summary.merge_all()

            if self.args.loadModel:
                # load model from disk
                if not os.path.exists(self.model_path):
                    print('model does not exist on disk!')
                    print(self.model_path)
                    exit(-1)

                self.saver.restore(sess=self.sess, save_path=self.model_name)
                print('Variables loaded from disk {}'.format(self.model_name))
            else:
                init = tf.global_variables_initializer()
                # initialize all global variables
                self.sess.run(init)
                print('All variables initialized')

            if not self.args.testModel:
                self.train(self.sess)
            else:
                self.testModel(self.sess)

    def train(self, sess):
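        # One pass over the training batches per epoch; val and test are
        # evaluated after each epoch, and a checkpoint is saved whenever the
        # validation accuracy matches or beats the best seen so far.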
        #sess = tf_debug.LocalCLIDebugWrapperSession(sess)

        print('Start training')

        out = open(self.outFile, 'w', 1)
        out.write(self.outFile + '\n')
        utils.writeInfo(out, self.args)

        current_valAcc = 0.0

        for e in range(self.args.epochs):
            # training
            #trainBatches = self.textData.train_batches
            trainBatches = self.textData.get_batches(tag='train')
            totalTrainLoss = 0.0

            # cnt of batches
            cnt = 0

            total_samples = 0
            total_corrects = 0
            all_skip_rate = []
            for idx, nextBatch in enumerate(tqdm(trainBatches)):
                cnt += 1
                #nextBatch = trainBatches[227]
                self.globalStep += 1

                total_samples += nextBatch.batch_size

                # Leftover debug code: tf.contrib.eager.py_func requires
                # (func, inp, Tout), and calling buildNetwork() again here
                # would rebuild the graph on every step, so it stays disabled.
                # py = tf.contrib.eager.py_func(self.model.buildNetwork())

                ops, feed_dict, length = self.model.step(nextBatch, test=False)
                # sess.run(py, feed_dict=feed_dict)

                # skip_rate: batch_size * n_samples
                _, loss, predictions, corrects, skip_rate = sess.run(
                    ops, feed_dict)
                all_skip_rate.extend(skip_rate.tolist())
                #print(loss, idx)
                total_corrects += corrects
                totalTrainLoss += loss

                self.summaryWriter.add_summary(
                    utils.makeSummary({"train_loss": loss}), self.globalStep)

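            # Each example contributes self.args.nSamples predictions per
            # step, so accuracy is normalised by total_samples * nSamples.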
            trainAcc = total_corrects * 1.0 / (total_samples *
                                               self.args.nSamples)
            train_skip_rate = np.average(all_skip_rate)
            print(
                '\nepoch = {}, Train, loss = {}, trainAcc = {}, train_skip_rate = {}'
                .format(e, totalTrainLoss, trainAcc, train_skip_rate))

            #continue
            out.write(
                '\nepoch = {}, loss = {}, trainAcc = {}, train_skip_rate = {}\n'
                .format(e, totalTrainLoss, trainAcc, train_skip_rate))
            out.flush()

            valAcc, valLoss, val_skip_rate = self.test(sess, tag='val')
            testAcc, testLoss, test_skip_rate = self.test(sess, tag='test')

            print('\tVal, loss = {}, valAcc = {}, val_skip_rate = {}'.format(
                valLoss, valAcc, val_skip_rate))
            out.write(
                '\tVal, loss = {}, valAcc = {}, val_skip_rate = {}\n'.format(
                    valLoss, valAcc, val_skip_rate))

            print(
                '\tTest, loss = {}, testAcc = {}, test_skip_rate = {}'.format(
                    testLoss, testAcc, test_skip_rate))
            out.write('\tTest, loss = {}, testAcc = {}, test_skip_rate = {}\n'.
                      format(testLoss, testAcc, test_skip_rate))

            self.summaryWriter.add_summary(
                utils.makeSummary({"train_acc": trainAcc}), e)
            self.summaryWriter.add_summary(
                utils.makeSummary({"val_acc": valAcc}), e)
            self.summaryWriter.add_summary(
                utils.makeSummary({"test_acc": testAcc}), e)
            self.summaryWriter.add_summary(
                utils.makeSummary({"train_skip_rate": train_skip_rate}), e)
            self.summaryWriter.add_summary(
                utils.makeSummary({"val_skip_rate": val_skip_rate}), e)
            self.summaryWriter.add_summary(
                utils.makeSummary({"test_skip_rate": test_skip_rate}), e)
            # we do not use cross val currently, just train, then evaluate
            if valAcc >= current_valAcc:
                # with open('skip_train.pkl', 'wb') as f:
                # 	p.dump(all_skip_rate, f)
                current_valAcc = valAcc
                print('New valAcc {} at epoch {}'.format(valAcc, e))
                out.write('New valAcc {} at epoch {}\n'.format(valAcc, e))
                save_path = self.saver.save(sess, save_path=self.model_name)
                print('model saved at {}'.format(save_path))
                out.write('model saved at {}\n'.format(save_path))

            out.flush()
        out.close()

    def test(self, sess, tag='val'):
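        # Evaluate on the val or test split; returns (accuracy, summed loss,
        # average skip rate over all samples).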
        if tag == 'val':
            print('Validating\n')
            batches = self.textData.val_batches
        else:
            print('Testing\n')
            batches = self.textData.test_batches

        cnt = 0

        total_samples = 0
        total_corrects = 0
        total_loss = 0.0
        all_predictions = []
        all_skip_rate = []
        for idx, nextBatch in enumerate(tqdm(batches)):
            cnt += 1

            total_samples += nextBatch.batch_size
            ops, feed_dict, length = self.model.step(nextBatch, test=True)

            loss, predictions, corrects, skip_rate = sess.run(ops, feed_dict)
            all_skip_rate.extend(skip_rate)
            all_predictions.extend(predictions)
            total_loss += loss
            total_corrects += corrects

            total_length = np.sum(length)

        # plt.hist(all_skip_rate)
        # plt.savefig('tmp.png')
        # print(np.average(all_skip_rate))

        acc = total_corrects * 1.0 / total_samples
        return acc, total_loss, np.average(all_skip_rate)

    def testModel(self, sess):
        acc, total_loss, _ = self.test(sess, tag='test')
        print('acc = {}, total_loss = {}'.format(acc, total_loss))
class Train:
    def __init__(self):
        self.args = None

        self.textData = None
        self.model = None
        self.outFile = None
        self.sess = None
        self.saver = None
        self.model_name = None
        self.model_path = None
        self.globalStep = 0
        self.summaryDir = None
        self.testOutFile = None
        self.summaryWriter = None
        self.mergedSummary = None

    @staticmethod
    def parse_args(args):
        parser = argparse.ArgumentParser()

        parser.add_argument('--resultDir',
                            type=str,
                            default='result',
                            help='result directory')
        parser.add_argument('--testDir', type=str, default='test_result')
        # data location
        dataArgs = parser.add_argument_group('Dataset options')

        dataArgs.add_argument('--summaryDir', type=str, default='summaries')
        dataArgs.add_argument('--datasetName',
                              type=str,
                              default='dataset',
                              help='a TextData object')

        dataArgs.add_argument('--dataDir',
                              type=str,
                              default='data',
                              help='dataset directory, save pkl here')
        dataArgs.add_argument('--dataset', type=str, default='rotten')
        dataArgs.add_argument('--trainFile', type=str, default='train.txt')
        dataArgs.add_argument('--valFile', type=str, default='val.txt')
        dataArgs.add_argument('--testFile', type=str, default='test.txt')
        dataArgs.add_argument('--embeddingFile',
                              type=str,
                              default='glove.840B.300d.txt')
        dataArgs.add_argument('--vocabSize',
                              type=int,
                              default=-1,
                              help='vocab size, use the most frequent words')
        dataArgs.add_argument('--gatesFileTrain',
                              type=str,
                              default='gates_train_')
        dataArgs.add_argument('--gatesFileVal', type=str, default='gates_val_')
        dataArgs.add_argument('--gatesFileTest',
                              type=str,
                              default='gates_test_')
        # neural network options
        nnArgs = parser.add_argument_group('Network options')
        nnArgs.add_argument('--embeddingSize', type=int, default=300)
        nnArgs.add_argument('--hiddenSize',
                            type=int,
                            default=200,
                            help='hiddenSize for RNN sentence encoder')
        nnArgs.add_argument('--rnnLayers', type=int, default=1)
        nnArgs.add_argument('--maxSteps', type=int, default=50)
        nnArgs.add_argument('--numClasses', type=int, default=2)
        nnArgs.add_argument('--skim', action='store_true')
        # training options
        trainingArgs = parser.add_argument_group('Training options')
        trainingArgs.add_argument('--modelPath', type=str, default='saved')
        trainingArgs.add_argument('--preEmbedding', action='store_true')
        trainingArgs.add_argument('--dropOut',
                                  type=float,
                                  default=1.0,
                                  help='dropout rate for RNN (keep prob)')
        trainingArgs.add_argument('--learningRate',
                                  type=float,
                                  default=0.001,
                                  help='learning rate')
        trainingArgs.add_argument('--batchSize',
                                  type=int,
                                  default=32,
                                  help='batch size')
        trainingArgs.add_argument('--skimloss',
                                  action='store_true',
                                  help='whether or not to encourage skimming')
        trainingArgs.add_argument(
            '--minRead',
            type=int,
            default=8,
            help='minimum number of tokens read before a jump')
        trainingArgs.add_argument('--maxSkip',
                                  type=int,
                                  default=5,
                                  help='maximum number of steps skipped in a single jump')
        trainingArgs.add_argument(
            '--maxJump',
            type=int,
            default=-1,
            help='maximum number of jumps in a sequence; -1 means the '
            'ACL-style jump limit is not used')
        trainingArgs.add_argument('--discount', type=float, default=0.99)
        trainingArgs.add_argument('--epochs',
                                  type=int,
                                  default=300,
                                  help='maximum number of training epochs')
        trainingArgs.add_argument('--device',
                                  type=str,
                                  default='/gpu:0',
                                  help='use the first GPU as default')
        trainingArgs.add_argument('--loadModel',
                                  action='store_true',
                                  help='whether or not to use old models')
        trainingArgs.add_argument('--testModel', action='store_true')
        trainingArgs.add_argument('--printgate', action='store_true')
        trainingArgs.add_argument('--nSamples', type=int, default=3)
        trainingArgs.add_argument('--eps', type=float, default=0.1)
        trainingArgs.add_argument('--random', action='store_true')
        trainingArgs.add_argument('--sparse',
                                  type=float,
                                  default=10.0,
                                  help='coefficient for sparse penalty')
        trainingArgs.add_argument(
            '--percent',
            action='store_true',
            help='whether or not to use a percentage cut-off for hard-gate init')
        trainingArgs.add_argument('--threshold',
                                  type=float,
                                  default=0.5,
                                  help='threshold for hard-gate init')
        trainingArgs.add_argument('--transferEpochs',
                                  type=int,
                                  default=150,
                                  help='number of epochs for transferring')
        trainingArgs.add_argument('--next',
                                  action='store_true',
                                  help="gate-prediction style: use the 'next' gates files instead of 'current'")
        trainingArgs.add_argument(
            '--all',
            action='store_true',
            help='whether or not to also run val/test in transfer mode')
        return parser.parse_args(args)

    def process_gates_files(self):
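        # Select the gates CSV files matching the gate-prediction style:
        # '<prefix>next.csv' when --next is set, '<prefix>current.csv' otherwise.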
        if self.args.next:
            self.args.gatesFileTrain = self.args.gatesFileTrain + 'next.csv'
            self.args.gatesFileVal = self.args.gatesFileVal + 'next.csv'
            self.args.gatesFileTest = self.args.gatesFileTest + 'next.csv'
        else:
            self.args.gatesFileTrain = self.args.gatesFileTrain + 'current.csv'
            self.args.gatesFileVal = self.args.gatesFileVal + 'current.csv'
            self.args.gatesFileTest = self.args.gatesFileTest + 'current.csv'

    def main(self, args=None):
        print('TensorFlow version {}'.format(tf.VERSION))

        # initialize args
        self.args = self.parse_args(args)

        self.process_gates_files()

        self.resultDir = os.path.join(self.args.resultDir, self.args.dataset)
        self.summaryDir = os.path.join(self.args.summaryDir, self.args.dataset)
        self.dataDir = os.path.join(self.args.dataDir, self.args.dataset)

        self.outFile = utils.constructFileName(self.args,
                                               prefix=self.resultDir)
        self.args.datasetName = utils.constructFileName(
            self.args, prefix=self.args.dataset, createDataSetName=True)
        if self.args.next:
            self.args.datasetName = 'next_' + self.args.datasetName
        else:
            self.args.datasetName = 'current_' + self.args.datasetName
        datasetFileName = os.path.join(self.dataDir, self.args.datasetName)

        if not os.path.exists(self.resultDir):
            os.makedirs(self.resultDir)

        if not os.path.exists(self.args.modelPath):
            os.makedirs(self.args.modelPath)

        if not os.path.exists(self.summaryDir):
            os.makedirs(self.summaryDir)

        if not os.path.exists(datasetFileName):
            self.textData = TextData(self.args)
            with open(datasetFileName, 'wb') as datasetFile:
                p.dump(self.textData, datasetFile)
            print('dataset created and saved to {}'.format(datasetFileName))
        else:
            with open(datasetFileName, 'rb') as datasetFile:
                self.textData = p.load(datasetFile)
            print('dataset loaded from {}'.format(datasetFileName))

        # init hard gates
        self.init_hard_gates(self.args.threshold, self.args.percent)
        sessConfig = tf.ConfigProto(allow_soft_placement=True)
        sessConfig.gpu_options.allow_growth = True

        self.modelPath = os.path.join(self.args.modelPath, self.args.dataset)
        self.model_path = utils.constructFileName(self.args,
                                                  prefix=self.modelPath,
                                                  tag='model')
        if not os.path.exists(self.model_path):
            os.makedirs(self.model_path)
        self.model_name = os.path.join(self.model_path, 'model')

        self.sess = tf.Session(config=sessConfig)
        # summary writer
        self.summaryDir = utils.constructFileName(self.args,
                                                  prefix=self.summaryDir)

        with tf.device(self.args.device):
            if self.args.skim:
                print('Skim model created')
                self.model = Model(self.args, self.textData)
            else:
                print('Ordinary model created!')
                self.model = ModelBasic(self.args, self.textData)

            # saver can only be created after we have the model
            self.saver = tf.train.Saver()

            self.summaryWriter = tf.summary.FileWriter(self.summaryDir,
                                                       self.sess.graph)
            self.mergedSummary = tf.summary.merge_all()

            if self.args.loadModel:
                # load model from disk
                if not os.path.exists(self.model_path):
                    print('model does not exist on disk!')
                    print(self.model_path)
                    exit(-1)

                self.saver.restore(sess=self.sess, save_path=self.model_name)
                print('Variables loaded from disk {}'.format(self.model_name))
            else:
                init = tf.global_variables_initializer()
                # initialize all global variables
                self.sess.run(init)
                print('All variables initialized')

            if not self.args.testModel:
                self.train(self.sess)
            else:
                self.testModel(self.sess)

    def train(self, sess):
        #sess = tf_debug.LocalCLIDebugWrapperSession(sess)

        print('Start training')

        out = open(self.outFile, 'w', 1)
        out.write(self.outFile + '\n')
        utils.writeInfo(out, self.args)

        current_valAcc = 0.0

        for e in range(self.args.epochs):
            # training
            #trainBatches = self.textData.train_batches
            trainBatches = self.textData.get_batches(tag='train')
            totalTrainLoss = 0.0

            # cnt of batches
            cnt = 0

            total_samples = 0
            total_corrects = 0
            all_skip_rate = []
            all_correct_predicted_inference_skips = []
            total_valids = 0
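            # For e <= transferEpochs + 1 the model is stepped with
            # is_transfering=True; afterwards it switches to
            # is_transfering=False ('RL begins').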
            if e == self.args.transferEpochs + 1:
                print('RL begins!')
                out.write('RL begins\n')

            for idx, nextBatch in enumerate(tqdm(trainBatches)):
                cnt += 1
                #nextBatch = trainBatches[227]
                self.globalStep += 1

                total_samples += nextBatch.batch_size
                if e > self.args.transferEpochs + 1:
                    ops, feed_dict, length = self.model.step(
                        nextBatch, test=False, is_transfering=False)
                else:
                    # transfer only during the first transferEpochs + 1 epochs
                    ops, feed_dict, length = self.model.step(
                        nextBatch, test=False, is_transfering=True)
                # skip_rate: batch_size * n_samples
                _, loss, predictions, corrects, skip_rate, correct_predicted_inference_skips, n_valids_sum = sess.run(
                    ops, feed_dict)
                all_correct_predicted_inference_skips.extend(
                    correct_predicted_inference_skips.tolist())
                all_skip_rate.extend(skip_rate.tolist())
                total_valids += n_valids_sum
                #print(loss, idx)
                total_corrects += corrects
                totalTrainLoss += loss

                self.summaryWriter.add_summary(
                    utils.makeSummary({"train_loss": loss}), self.globalStep)
            trainAcc = total_corrects * 1.0 / (total_samples *
                                               self.args.nSamples)
            train_skip_rate = np.average(all_skip_rate)
            train_skip_acc = np.sum(
                all_correct_predicted_inference_skips) / total_valids

            print(
                '\nepoch = {}, Train, loss = {}, trainAcc = {}, train_skip_rate = {}, train_skip_acc = {}'
                .format(e, totalTrainLoss, trainAcc, train_skip_rate,
                        train_skip_acc))
            out.write(
                '\nepoch = {}, loss = {}, trainAcc = {}, train_skip_rate = {}, train_skip_acc = {}\n'
                .format(e, totalTrainLoss, trainAcc, train_skip_rate,
                        train_skip_acc))
            out.flush()
            valAcc, valLoss, val_skip_rate, val_skip_acc = self.test(sess,
                                                                     tag='val')
            testAcc, testLoss, test_skip_rate, test_skip_acc = self.test(
                sess, tag='test')

            print(
                '\tVal, loss = {}, valAcc = {}, val_skip_rate = {}, val_skip_acc = {}'
                .format(valLoss, valAcc, val_skip_rate, val_skip_acc))
            out.write(
                '\tVal, loss = {}, valAcc = {}, val_skip_rate = {}, val_skip_acc = {}\n'
                .format(valLoss, valAcc, val_skip_rate, val_skip_acc))

            print(
                '\tTest, loss = {}, testAcc = {}, test_skip_rate = {}, test_skip_acc = {}'
                .format(testLoss, testAcc, test_skip_rate, test_skip_acc))
            out.write(
                '\tTest, loss = {}, testAcc = {}, test_skip_rate = {}, test_skip_acc = {}\n'
                .format(testLoss, testAcc, test_skip_rate, test_skip_acc))

            self.summaryWriter.add_summary(
                utils.makeSummary({"train_acc": trainAcc}), e)
            self.summaryWriter.add_summary(
                utils.makeSummary({"val_acc": valAcc}), e)
            self.summaryWriter.add_summary(
                utils.makeSummary({"test_acc": testAcc}), e)
            self.summaryWriter.add_summary(
                utils.makeSummary({"train_skip_rate": train_skip_rate}), e)
            self.summaryWriter.add_summary(
                utils.makeSummary({"val_skip_rate": val_skip_rate}), e)
            self.summaryWriter.add_summary(
                utils.makeSummary({"test_skip_rate": test_skip_rate}), e)
            self.summaryWriter.add_summary(
                utils.makeSummary({"train_skip_acc": train_skip_acc}), e)
            self.summaryWriter.add_summary(
                utils.makeSummary({"val_skip_acc": val_skip_acc}), e)
            self.summaryWriter.add_summary(
                utils.makeSummary({"test_skip_acc": test_skip_acc}), e)
            # we do not use cross val currently, just train, then evaluate
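            # The best-val tracker is reset when the RL phase starts, so
            # improvements after the transfer phase are reported separately
            # ('New valAcc2').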
            if e == self.args.transferEpochs + 1:
                current_valAcc = 0
            if valAcc >= current_valAcc:
                # with open('skip_train.pkl', 'wb') as f:
                # 	p.dump(all_skip_rate, f)
                current_valAcc = valAcc
                if e < self.args.transferEpochs:
                    print('New valAcc {} at epoch {}'.format(valAcc, e))
                    out.write('New valAcc {} at epoch {}\n'.format(valAcc, e))
                else:
                    print('New valAcc2 {} at epoch {}'.format(valAcc, e))
                    out.write('New valAcc2 {} at epoch {}\n'.format(valAcc, e))
                # save_path = self.saver.save(sess, save_path=self.model_name)
                # print('model saved at {}'.format(save_path))
                # out.write('model saved at {}\n'.format(save_path))

            out.flush()
        out.close()

    def test(self, sess, tag='val'):
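        # Evaluate on val/test; returns (accuracy, summed loss, average skip
        # rate, and the share of correctly predicted inference-time skips).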
        if tag == 'val':
            print('Validating\n')
            batches = self.textData.val_batches
        else:
            print('Testing\n')
            batches = self.textData.test_batches

        cnt = 0

        total_samples = 0
        total_corrects = 0
        total_loss = 0.0
        all_predictions = []
        all_skip_rate = []
        total_valids = 0
        all_correct_predicted_inference_skips = []

        for idx, nextBatch in enumerate(tqdm(batches)):
            cnt += 1

            total_samples += nextBatch.batch_size
            # run val/test in transfer mode as well when --all is set
            is_transfering = self.args.all
            ops, feed_dict, length = self.model.step(
                nextBatch, test=True, is_transfering=is_transfering)

            loss, predictions, corrects, skip_rate, correct_predicted_inference_skips, n_valids_sum = sess.run(
                ops, feed_dict)
            all_correct_predicted_inference_skips.extend(
                correct_predicted_inference_skips.tolist())
            all_skip_rate.extend(skip_rate.tolist())
            total_valids += n_valids_sum
            all_predictions.extend(predictions)
            total_loss += loss
            total_corrects += corrects

            total_length = np.sum(length)

        # plt.hist(all_skip_rate)
        # plt.savefig('tmp.png')
        # print(np.average(all_skip_rate))
        predicted_skip_acc = np.sum(
            all_correct_predicted_inference_skips) / total_valids
        acc = total_corrects * 1.0 / total_samples
        return acc, total_loss, np.average(all_skip_rate), predicted_skip_acc

    def testModel(self, sess):
        acc, total_loss, _, _ = self.test(sess, tag='test')
        print('acc = {}, total_loss = {}'.format(acc, total_loss))

    def init_hard_gates(self, threshold, percent):
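        # Initialise hard gates for every train/val/test batch using the given
        # threshold (or a percentage cut-off when --percent is set) and the
        # --maxSkip limit.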
        train_batch = self.textData.train_batches
        for batch in train_batch:
            batch.init_hard_gates(threshold=threshold,
                                  percent=percent,
                                  max_skip=self.args.maxSkip)

        val_batch = self.textData.val_batches
        for batch in val_batch:
            batch.init_hard_gates(threshold=threshold,
                                  percent=percent,
                                  max_skip=self.args.maxSkip)

        test_batch = self.textData.test_batches
        for batch in test_batch:
            batch.init_hard_gates(threshold=threshold,
                                  percent=percent,
                                  max_skip=self.args.maxSkip)