Example #1
0
    def train(self, sess):
        """Run the training loop for ``args.epochs`` epochs.

        Each epoch: iterate the training batches, accumulate loss, log
        per-step loss/perplexity summaries, then evaluate on the val and
        test splits and checkpoint whenever validation loss improves.

        Args:
            sess: an open TensorFlow session used for all graph execution.
        """
        print('Start training')

        # Line-buffered (buffering=1) log file, managed by a context manager
        # so the handle is released even if training raises mid-epoch.
        with open(self.outFile, 'w', 1) as out:
            out.write(self.outFile + '\n')
            utils.writeInfo(out, self.args)

            # Lowest validation loss seen so far; drives checkpointing.
            current_val_loss = np.inf

            for e in range(self.args.epochs):
                trainBatches = self.textData.get_batches(tag='train')
                totalTrainLoss = 0.0
                # Total token count this epoch, for per-token perplexity.
                total_steps = 0

                for nextBatch in tqdm(trainBatches):
                    self.globalStep += 1

                    for sample in nextBatch.samples:
                        total_steps += sample.length

                    ops, feed_dict = self.model.step(nextBatch, test=False)
                    _, loss, trainPerplexity = sess.run(ops, feed_dict)
                    totalTrainLoss += loss

                    # Average perplexity across samples in this step.
                    trainPerplexity = np.mean(trainPerplexity)
                    self.summaryWriter.add_summary(utils.makeSummary({"trainLoss": loss}), self.globalStep)
                    self.summaryWriter.add_summary(utils.makeSummary({"trainPerplexity": trainPerplexity}), self.globalStep)

                # Perplexity over all tokens seen in this epoch.
                trainPerplexity = np.exp(totalTrainLoss / total_steps)

                print('\nepoch = {}, Train, loss = {}, perplexity = {}'.
                      format(e, totalTrainLoss, trainPerplexity))
                out.write('\nepoch = {}, loss = {}, perplexity = {}\n'.
                      format(e, totalTrainLoss, trainPerplexity))
                out.flush()

                valLoss, val_num = self.test(sess, tag='val')
                testLoss, test_num = self.test(sess, tag='test')

                valPerplexity = np.exp(valLoss / val_num)
                testPerplexity = np.exp(testLoss / test_num)

                print('Val, loss = {}, perplexity = {}'.
                      format(valLoss, valPerplexity))
                out.write('Val, loss = {}, perplexity = {}\n'.
                      format(valLoss, valPerplexity))

                print('Test, loss = {}, perplexity = {}'.
                      format(testLoss, testPerplexity))
                out.write('Test, loss = {}, perplexity = {}\n'.
                      format(testLoss, testPerplexity))

                # Checkpoint only when validation loss improves.
                if valLoss < current_val_loss:
                    current_val_loss = valLoss
                    print('New val loss {} at epoch {}'.format(valLoss, e))
                    out.write('New val loss {} at epoch {}\n'.format(valLoss, e))
                    save_path = self.saver.save(sess, save_path=self.model_name)
                    print('model saved at {}'.format(save_path))
                    out.write('model saved at {}\n'.format(save_path))

                out.flush()
Example #2
0
    def train(self, sess):
        """Train a classifier for ``args.epochs`` epochs.

        Each epoch: iterate the training batches, accumulate loss and
        correct-prediction counts, evaluate accuracy on the val and test
        splits, and checkpoint whenever validation accuracy improves.

        Args:
            sess: an open TensorFlow session used for all graph execution.
        """
        print('Start training')

        # Line-buffered (buffering=1) log file, managed by a context manager
        # so the handle is released even if training raises mid-epoch.
        with open(self.outFile, 'w', 1) as out:
            out.write(self.outFile + '\n')
            utils.writeInfo(out, self.args)

            # Best validation accuracy seen so far; drives checkpointing.
            current_valAcc = 0.0

            for e in range(self.args.epochs):
                trainBatches = self.textData.get_batches(tag='train')
                totalTrainLoss = 0.0
                total_samples = 0
                total_corrects = 0

                for nextBatch in tqdm(trainBatches):
                    self.globalStep += 1
                    total_samples += nextBatch.batch_size

                    ops, feed_dict = self.model.step(nextBatch, test=False)
                    _, loss, predictions, corrects = sess.run(ops, feed_dict)
                    total_corrects += corrects
                    totalTrainLoss += loss

                    self.summaryWriter.add_summary(utils.makeSummary({"trainLoss": loss}), self.globalStep)

                # Epoch-level training accuracy.
                trainAcc = total_corrects*1.0/total_samples
                print('\nepoch = {}, Train, loss = {}, trainAcc = {}'.
                      format(e, totalTrainLoss, trainAcc))
                out.write('\nepoch = {}, loss = {}, trainAcc = {}\n'.
                      format(e, totalTrainLoss, trainAcc))
                out.flush()

                valAcc, valLoss = self.test(sess, tag='val')
                print('Val, loss = {}, valAcc = {}'.
                      format(valLoss, valAcc))
                out.write('Val, loss = {}, valAcc = {}\n'.
                      format(valLoss, valAcc))

                testAcc, testLoss = self.test(sess, tag='test')
                print('Test, loss = {}, testAcc = {}'.
                      format(testLoss, testAcc))
                out.write('Test, loss = {}, testAcc = {}\n'.
                      format(testLoss, testAcc))

                # Checkpoint only when validation accuracy matches or beats
                # the best so far (>= keeps the latest of tied bests, as in
                # the original behavior).
                if valAcc >= current_valAcc:
                    current_valAcc = valAcc
                    print('New valAcc {} at epoch {}'.format(valAcc, e))
                    out.write('New valAcc {} at epoch {}\n'.format(valAcc, e))
                    save_path = self.saver.save(sess, save_path=self.model_name)
                    print('model saved at {}'.format(save_path))
                    out.write('model saved at {}\n'.format(save_path))

                out.flush()
Example #3
0
	def train(self, sess):
		"""Train for ``args.epochs`` epochs, tracking accuracy and skip rate.

		Each epoch: iterate the training batches accumulating loss, correct
		counts and per-sample skip rates, evaluate on the val and test
		splits, emit per-epoch summaries, and checkpoint whenever validation
		accuracy improves.

		Fixes: the original had debug leftovers — a ``break`` after the
		first batch and a ``continue`` before evaluation — that trained on
		a single batch per epoch and made all validation/test/checkpoint
		code unreachable. Both are removed here.

		Args:
			sess: an open TensorFlow session used for all graph execution.
		"""
		print('Start training')

		# Line-buffered (buffering=1) log file, managed by a context manager
		# so the handle is released even if training raises mid-epoch.
		with open(self.outFile, 'w', 1) as out:
			out.write(self.outFile + '\n')
			utils.writeInfo(out, self.args)

			# Best validation accuracy seen so far; drives checkpointing.
			current_valAcc = 0.0

			for e in range(self.args.epochs):
				trainBatches = self.textData.train_batches
				totalTrainLoss = 0.0
				total_samples = 0
				total_corrects = 0
				all_skip_rate = []

				for nextBatch in tqdm(trainBatches):
					self.globalStep += 1
					total_samples += nextBatch.batch_size

					# length is returned by step() but unused here.
					ops, feed_dict, _length = self.model.step(nextBatch, test=False)
					# skip_rate: batch_size * n_samples
					_, loss, predictions, corrects, skip_rate, skip_flag = sess.run(ops, feed_dict)
					all_skip_rate.extend(skip_rate.tolist())
					total_corrects += corrects
					totalTrainLoss += loss

					self.summaryWriter.add_summary(utils.makeSummary({"train_loss": loss}), self.globalStep)

				# Each batch sample contributes nSamples predictions.
				trainAcc = total_corrects * 1.0 / (total_samples*self.args.nSamples)
				train_skip_rate = np.average(all_skip_rate)
				print('\nepoch = {}, Train, loss = {}, trainAcc = {}, train_skip_rate = {}'.
				      format(e, totalTrainLoss, trainAcc, train_skip_rate))
				out.write('\nepoch = {}, loss = {}, trainAcc = {}, train_skip_rate = {}\n'.
				          format(e, totalTrainLoss, trainAcc, train_skip_rate))
				out.flush()

				valAcc, valLoss, val_skip_rate = self.test(sess, tag='val')
				testAcc, testLoss, test_skip_rate = self.test(sess, tag='test')

				print('\tVal, loss = {}, valAcc = {}, val_skip_rate = {}'.
				      format(valLoss, valAcc, val_skip_rate))
				out.write('\tVal, loss = {}, valAcc = {}, val_skip_rate = {}\n'.
				          format(valLoss, valAcc, val_skip_rate))

				print('\tTest, loss = {}, testAcc = {}, test_skip_rate = {}'.
				      format(testLoss, testAcc, test_skip_rate))
				out.write('\tTest, loss = {}, testAcc = {}, test_skip_rate = {}\n'.
				          format(testLoss, testAcc, test_skip_rate))

				# Per-epoch scalar summaries (indexed by epoch, not step).
				self.summaryWriter.add_summary(utils.makeSummary({"train_acc": trainAcc}), e)
				self.summaryWriter.add_summary(utils.makeSummary({"val_acc": valAcc}), e)
				self.summaryWriter.add_summary(utils.makeSummary({"test_acc": testAcc}), e)
				self.summaryWriter.add_summary(utils.makeSummary({"train_skip_rate": train_skip_rate}), e)
				self.summaryWriter.add_summary(utils.makeSummary({"val_skip_rate": val_skip_rate}), e)
				self.summaryWriter.add_summary(utils.makeSummary({"test_skip_rate": test_skip_rate}), e)

				# Checkpoint only when validation accuracy matches or beats
				# the best so far.
				if valAcc >= current_valAcc:
					current_valAcc = valAcc
					print('New valAcc {} at epoch {}'.format(valAcc, e))
					out.write('New valAcc {} at epoch {}\n'.format(valAcc, e))
					save_path = self.saver.save(sess, save_path=self.model_name)
					print('model saved at {}'.format(save_path))
					out.write('model saved at {}\n'.format(save_path))

				out.flush()
Example #4
0
    def train(self, sess):
        """Train for ``args.epochs`` epochs with F1-based model selection.

        Each epoch: iterate the (optionally augmented) training batches,
        accumulate loss/accuracy and collect predictions, labels and sample
        weights; compute weighted and unweighted micro/macro precision,
        recall and F1 via ``self.cal_F1``; evaluate on the val and test
        splits via ``self.test``; and, whenever the unweighted validation
        micro-F1 improves, checkpoint the model and write test-set
        predictions.

        Args:
            sess: an open TensorFlow session used for all graph execution.
        """
        #sess = tf_debug.LocalCLIDebugWrapperSession(sess)

        print('Start training')

        # Line-buffered (buffering=1) log file; closed at the end of training.
        out = open(self.outFile, 'w', 1)
        out.write(self.outFile + '\n')
        utils.writeInfo(out, self.args)

        # Best unweighted validation micro-F1 seen so far; drives checkpointing.
        current_val_f1_micro_unweighted = 0.0

        for e in range(self.args.epochs):
            # training
            trainBatches = self.textData.get_batches(tag='train',
                                                     augment=self.args.augment)
            #trainBatches = self.textData.train_batches
            totalTrainLoss = 0.0

            # cnt of batches
            cnt = 0

            total_samples = 0
            total_corrects = 0

            # Accumulated across all batches this epoch, for F1 computation.
            all_predictions = []
            all_labels = []
            all_sample_weights = []

            for idx, nextBatch in enumerate(tqdm(trainBatches)):

                cnt += 1
                self.globalStep += 1
                total_samples += nextBatch.batch_size
                # print(idx)

                ops, feed_dict, labels, sample_weights = self.model.step(
                    nextBatch, test=False)
                _, loss, predictions, corrects = sess.run(ops, feed_dict)
                all_predictions.extend(predictions)
                all_labels.extend(labels)
                all_sample_weights.extend(sample_weights)
                total_corrects += corrects
                totalTrainLoss += loss

                self.summaryWriter.add_summary(
                    utils.makeSummary({"train_loss": loss}), self.globalStep)
                #break
            trainAcc = total_corrects * 1.0 / total_samples

            # calculate f1 score for train (weighted/unweighted)
            train_f1_micro, train_f1_macro, train_p_micro, train_r_micro, train_p_macro, train_r_macro\
             = self.cal_F1(y_pred=all_predictions, y_true=all_labels)
            train_f1_micro_w, train_f1_macro_w, train_p_micro_w, train_r_micro_w, train_p_macro_w, train_r_macro_w\
             = self.cal_F1(y_pred=all_predictions, y_true=all_labels,
                                                             sample_weight=all_sample_weights)

            # Log epoch-level training metrics to stdout and the log file.
            print(
                '\nepoch = {}, Train, loss = {}, trainAcc = {}, train_f1_micro = {}, train_f1_macro = {},'
                ' train_f1_micro_w = {}, train_f1_macro_w = {}'.format(
                    e, totalTrainLoss, trainAcc, train_f1_micro,
                    train_f1_macro, train_f1_micro_w, train_f1_macro_w))
            print(
                '\ttrain_p_micro = {}, train_r_micro = {}, train_p_macro = {}, train_r_macro = {}'
                .format(train_p_micro, train_r_micro, train_p_macro,
                        train_r_macro))
            print(
                '\ttrain_p_micro_w = {}, train_r_micro_w = {}, train_p_macro_w = {}, train_r_macro_w = {}'
                .format(train_p_micro_w, train_r_micro_w, train_p_macro_w,
                        train_r_macro_w))

            out.write(
                '\nepoch = {}, loss = {}, trainAcc = {}, train_f1_micro = {}, train_f1_macro = {},'
                ' train_f1_micro_w = {}, train_f1_macro_w = {}\n'.format(
                    e, totalTrainLoss, trainAcc, train_f1_micro,
                    train_f1_macro, train_f1_micro_w, train_f1_macro_w))
            out.write(
                '\ttrain_p_micro = {}, train_r_micro = {}, train_p_macro = {}, train_r_macro = {}\n'
                .format(train_p_micro, train_r_micro, train_p_macro,
                        train_r_macro))
            out.write(
                '\ttrain_p_micro_w = {}, train_r_micro_w = {}, train_p_macro_w = {}, train_r_macro_w = {}\n'
                .format(train_p_micro_w, train_r_micro_w, train_p_macro_w,
                        train_r_macro_w))
            #continue

            out.flush()

            # calculate f1 score for val (weighted/unweighted)
            valAcc, valLoss, val_f1_micro, val_f1_macro, val_f1_micro_w, val_f1_macro_w,\
            val_p_micro, val_r_micro, val_p_macro, val_r_macro, val_p_micro_w, val_r_micro_w, val_p_macro_w, val_r_macro_w\
             = self.test(sess, tag='val')

            print(
                '\n\tVal, loss = {}, valAcc = {}, val_f1_micro = {}, val_f1_macro = {}, val_f1_micro_w = {}, val_f1_macro_w = {}'
                .format(valLoss, valAcc, val_f1_micro, val_f1_macro,
                        val_f1_micro_w, val_f1_macro_w))
            print(
                '\t\t val_p_micro = {}, val_r_micro = {}, val_p_macro = {}, val_r_macro = {}'
                .format(val_p_micro, val_r_micro, val_p_macro, val_r_macro))
            print(
                '\t\t val_p_micro_w = {}, val_r_micro_w = {}, val_p_macro_w = {}, val_r_macro_w = {}'
                .format(val_p_micro_w, val_r_micro_w, val_p_macro_w,
                        val_r_macro_w))

            out.write(
                '\n\tVal, loss = {}, valAcc = {}, val_f1_micro = {}, val_f1_macro = {}, val_f1_micro_w = {}, val_f1_macro_w = {}\n'
                .format(valLoss, valAcc, val_f1_micro, val_f1_macro,
                        val_f1_micro_w, val_f1_macro_w))
            out.write(
                '\t\t val_p_micro = {}, val_r_micro = {}, val_p_macro = {}, val_r_macro = {}\n'
                .format(val_p_micro, val_r_micro, val_p_macro, val_r_macro))
            out.write(
                '\t\t val_p_micro_w = {}, val_r_micro_w = {}, val_p_macro_w = {}, val_r_macro_w = {}\n'
                .format(val_p_micro_w, val_r_micro_w, val_p_macro_w,
                        val_r_macro_w))

            # calculate f1 score for test (weighted/unweighted)
            testAcc, testLoss, test_f1_micro, test_f1_macro, test_f1_micro_w, test_f1_macro_w,\
            test_p_micro, test_r_micro, test_p_macro, test_r_macro, test_p_micro_w, test_r_micro_w, test_p_macro_w, test_r_macro_w\
             = self.test(sess, tag='test')

            print(
                '\n\ttest, loss = {}, testAcc = {}, test_f1_micro = {}, test_f1_macro = {}, test_f1_micro_w = {}, test_f1_macro_w = {}'
                .format(testLoss, testAcc, test_f1_micro, test_f1_macro,
                        test_f1_micro_w, test_f1_macro_w))
            print(
                '\t\t test_p_micro = {}, test_r_micro = {}, test_p_macro = {}, test_r_macro = {}'
                .format(test_p_micro, test_r_micro, test_p_macro,
                        test_r_macro))
            print(
                '\t\t test_p_micro_w = {}, test_r_micro_w = {}, test_p_macro_w = {}, test_r_macro_w = {}'
                .format(test_p_micro_w, test_r_micro_w, test_p_macro_w,
                        test_r_macro_w))

            out.write(
                '\n\ttest, loss = {}, testAcc = {}, test_f1_micro = {}, test_f1_macro = {}, test_f1_micro_w = {}, test_f1_macro_w = {}\n'
                .format(testLoss, testAcc, test_f1_micro, test_f1_macro,
                        test_f1_micro_w, test_f1_macro_w))
            out.write(
                '\t\t test_p_micro = {}, test_r_micro = {}, test_p_macro = {}, test_r_macro = {}\n'
                .format(test_p_micro, test_r_micro, test_p_macro,
                        test_r_macro))
            out.write(
                '\t\t test_p_micro_w = {}, test_r_micro_w = {}, test_p_macro_w = {}, test_r_macro_w = {}\n'
                .format(test_p_micro_w, test_r_micro_w, test_p_macro_w,
                        test_r_macro_w))

            out.flush()

            # Per-epoch accuracy summaries (indexed by epoch, not step).
            self.summaryWriter.add_summary(
                utils.makeSummary({"train_acc": trainAcc}), e)
            self.summaryWriter.add_summary(
                utils.makeSummary({"val_acc": valAcc}), e)

            # use val_f1_micro (unweighted) as metric
            if val_f1_micro >= current_val_f1_micro_unweighted:
                current_val_f1_micro_unweighted = val_f1_micro
                print('New val f1_micro {} at epoch {}'.format(
                    val_f1_micro, e))
                out.write('New val f1_micro {} at epoch {}\n'.format(
                    val_f1_micro, e))

                save_path = self.saver.save(sess, save_path=self.model_name)
                print('model saved at {}'.format(save_path))
                out.write('model saved at {}\n'.format(save_path))

                # NOTE(review): test(sess, tag='test2') appears to return raw
                # predictions rather than the metric tuple used above —
                # confirm against self.test's implementation.
                test_predictions = self.test(sess, tag='test2')
                print('Writing predictions at epoch {}'.format(e))
                out.write('Writing predictions at epoch {}\n'.format(e))
                test_file = self.write_predictions(test_predictions,
                                                   tag='unweighted')

                print('Writing predictions to {}'.format(test_file))
                out.write('Writing predictions to {}\n'.format(test_file))

            out.flush()
        out.close()