import tensorflow as tf
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from load_data import load_data_names, load_batch_from_data, save_data_to_tfrecord, save_data_to_tfrecord_without_face
import datetime
import random
import timeit

import numpy as np
from mtcnn.mtcnn import mtcnn_handle
from utility.data_utility import write_to_file

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# os.environ["CUDA_VISIBLE_DEVICES"]="1"
# os.environ["CUDA_VISIBLE_DEVICES"]="2"

mtcnn_h = mtcnn_handle()
now = datetime.datetime.now()
date = now.strftime("%Y-%m-%d-%H-%M")

record_file_name = "records/" + date + '_record.txt'
records = ""
records_count = 0

# dataset_path = "../data/GazeCapture"
dataset_path = "/data/public/gaze/raw/GazeCapture"
train_path = dataset_path + '/' + "train"
val_path = dataset_path + '/' + "validation"
test_path = dataset_path + '/' + "test"

img_cols = 64
img_rows = 64
img_ch = 3  # assumed: RGB channel count; used by load_batch_from_data below
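
# The helpers below are used by train() but are not imported above. These are
# minimal sketches of the assumed behavior, inferred from the call sites, not
# the project's actual implementations; prepare_data and compute_angle_error
# are likewise assumed to come from the project's utilities.


def next_batch_universal(data, batch_size, start):
    """Sketch: return the next (inputs, labels) batch and the advanced
    cursor, wrapping to the start when the data is exhausted."""
    inputs, labels = data
    if start + batch_size > len(inputs):
        start = 0
    end = start + batch_size
    return (inputs[start:end], labels[start:end]), end


def next_batch(data, batch_size):
    """Sketch: yield successive full batches from a tuple of parallel arrays."""
    n = len(data[0])
    for start in range(0, n - batch_size + 1, batch_size):
        yield [d[start:start + batch_size] for d in data]


def plot_loss(train_loss, train_err, val_loss, val_err,
              start=0, per=1, save_file="loss.png"):
    """Sketch: plot cumulative train/val loss and error curves to a file."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    ax1.plot(train_loss[start::per], label="train")
    ax1.plot(val_loss[start::per], label="val")
    ax1.set_title("loss")
    ax1.legend()
    ax2.plot(train_err[start::per], label="train")
    ax2.plot(val_err[start::per], label="val")
    ax2.set_title("error")
    ax2.legend()
    fig.savefig(save_file)
    plt.close(fig)
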
    def train(self,
              args,
              ckpt,
              plot_ckpt,
              lr=1e-3,
              batch_size=128,
              max_epoch=1000,
              min_delta=1e-4,
              patience=10,
              print_per_epoch=10):
        ifCheck = False  # set every 30 batches to trigger validation + checkpointing

        # --------------------------
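        # Build the extra single-eye datasets (train/val splits, per eye)
        # that feed the auxiliary angle losses defined further down.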
        train_data_eye_left, val_data_eye_left = self.organize_extra_eye_data(
            args, "left")
        train_data_eye_right, val_data_eye_right = self.organize_extra_eye_data(
            args, "right")

        # -----------------------------
        print("------ finish processing extra data  --------")

        train_names = load_data_names(train_path)
        val_names = load_data_names(val_path)

        train_num = len(train_names)
        val_num = len(val_names)

        print("train_num: ", train_num)
        print("test_num: ", val_num)

        MaxIters = train_num // batch_size
        n_batches = MaxIters

        val_chunk_size = 1000
        MaxTestIters = val_num // val_chunk_size
        val_n_batches = val_chunk_size / batch_size

        print("MaxIters: ", MaxIters)
        print("MaxTestIters: ", MaxTestIters)

        print('Train on %s samples, validate on %s samples' %
              (train_num, val_num))

        # Define losses and optimizers: cost1 is the gaze-point MSE over the
        # full input (both eyes, face, face mask); cost2/cost3 are auxiliary
        # squared-MSE terms on the per-eye angle predictions, trained from
        # the extra single-eye data. Each loss gets its own Adam optimizer
        # so the heads can be stepped independently.
        pred_xy, pred_ang_left, pred_ang_right = self.pred
        self.cost1 = tf.losses.mean_squared_error(self.y, pred_xy)
        self.cost2 = tf.losses.mean_squared_error(self.y, pred_ang_left)**2
        self.cost3 = tf.losses.mean_squared_error(self.y, pred_ang_right)**2

        self.optimizer1 = tf.train.AdamOptimizer(learning_rate=lr).minimize(
            self.cost1)
        self.optimizer2 = tf.train.AdamOptimizer(learning_rate=lr).minimize(
            self.cost2)
        self.optimizer3 = tf.train.AdamOptimizer(learning_rate=lr).minimize(
            self.cost3)

        # Evaluate model: err1 is the mean Euclidean distance between
        # predicted and ground-truth gaze points; err2/err3 are per-eye
        # angular errors.
        self.err1 = tf.reduce_mean(
            tf.sqrt(
                tf.reduce_sum(tf.squared_difference(pred_xy, self.y), axis=1)))
        self.err2 = compute_angle_error(self.y, pred_ang_left)
        self.err3 = compute_angle_error(self.y, pred_ang_right)

        train_loss_history = []
        train_err_history = []
        val_loss_history = []
        val_err_history = []

        train_loss_history_eye_left = []
        train_err_history_eye_left = []
        val_loss_history_eye_left = []
        val_err_history_eye_left = []

        train_loss_history_eye_right = []
        train_err_history_eye_right = []
        val_loss_history_eye_right = []
        val_err_history_eye_right = []

        best_loss = np.inf  # note: early stopping (min_delta / patience) is not implemented below

        # Register the input placeholders and prediction ops in a named
        # collection so a separate evaluation script can restore the graph
        # and look these tensors up without rebuilding the model.
        tf.get_collection("validation_nodes")
        tf.add_to_collection("validation_nodes", self.eye_left)
        tf.add_to_collection("validation_nodes", self.eye_right)
        tf.add_to_collection("validation_nodes", self.face)
        tf.add_to_collection("validation_nodes", self.face_mask)
        tf.add_to_collection("validation_nodes", pred_xy)
        tf.add_to_collection("validation_nodes", pred_ang_left)
        tf.add_to_collection("validation_nodes", pred_ang_right)

        # variables_to_restore = [var for var in tf.global_variables()]
        # saver = tf.train.Saver(variables_to_restore)

        saver = tf.train.Saver(max_to_keep=0)  # max_to_keep=0 keeps every checkpoint on disk

        # Initializing the variables
        init = tf.global_variables_initializer()
        # TODO://////
        # tf.reset_default_graph()
        # Launch the graph

        with tf.Session() as sess:
            sess.run(init)
            # TODO://////
            writer = tf.summary.FileWriter("logs", sess.graph)

            # saver.restore(sess, "./my_model/pretrained/model_4_1800_train_error_3.5047762_val_error_5.765135765075684")

            # saver.restore(sess, "./my_model/2018-09-18-11-01/model_1_300_train_error_history_2.8944669_val_error_history_3.092479933391918")

            # print " pass the restoring !!!!"

            # Reuse the module-level mtcnn_h face detector; no need to load
            # a second copy here.

            random.shuffle(val_names)

            # Keep training until reach max iterations
            for n_epoch in range(1, max_epoch + 1):
                print("vvvvvvvvvvvvvvvvvvv")
                print("n_epoch: ", n_epoch)
                epoch_start = timeit.default_timer()
                iter_start = None

                random.shuffle(train_names)

                iterTest = 0
                i_left = 0   # cursor into the extra left-eye data
                i_right = 0  # cursor into the extra right-eye data

                for iter in range(int(MaxIters)):

                    start = timeit.default_timer()

                    # print ("--------------------------------")
                    # print ("iter: ", iter)
                    train_start = iter * batch_size
                    train_end = (iter + 1) * batch_size

                    batch_train_data = load_batch_from_data(
                        mtcnn_h,
                        train_names,
                        dataset_path,
                        batch_size,
                        img_ch,
                        img_cols,
                        img_rows,
                        train_start=train_start,
                        train_end=train_end)
                    batch_train_data = prepare_data(batch_train_data)

                    print('Loading and preparing training data: %.1fs' %
                          (timeit.default_timer() - start))
                    start = timeit.default_timer()

                    # # Run optimization op (backprop)
                    sess.run(self.optimizer1, feed_dict={self.eye_left: batch_train_data[0], \
                       self.eye_right: batch_train_data[1], self.face: batch_train_data[2], \
                       self.face_mask: batch_train_data[3], self.y: batch_train_data[4]})

                    train_batch_loss, train_batch_err = sess.run([self.cost1, self.err1], feed_dict={self.eye_left: batch_train_data[0], \
                       self.eye_right: batch_train_data[1], self.face: batch_train_data[2], \
                       self.face_mask: batch_train_data[3], self.y: batch_train_data[4]})

                    # Interleave five auxiliary updates per main batch: each
                    # pass draws a fresh single-eye batch and steps the
                    # left-eye and right-eye angle heads separately.
                    for _ in range(5):
                        batch_train_data_eye_left, i_left = next_batch_universal(
                            train_data_eye_left, batch_size, i_left)

                        sess.run(self.optimizer2, feed_dict={self.eye_left: batch_train_data_eye_left[0], \
                           self.y: batch_train_data_eye_left[1]})

                        train_batch_loss_eye_left, train_batch_err_eye_left = sess.run([self.cost2, self.err2], feed_dict={self.eye_left: batch_train_data_eye_left[0], \
                           self.y: batch_train_data_eye_left[1]})

                        batch_train_data_eye_right, i_right = next_batch_universal(
                            train_data_eye_right, batch_size, i_right)

                        sess.run(self.optimizer3, feed_dict={self.eye_right: batch_train_data_eye_right[0], \
                           self.y: batch_train_data_eye_right[1]})

                        train_batch_loss_eye_right, train_batch_err_eye_right = sess.run([self.cost3, self.err3], feed_dict={self.eye_right: batch_train_data_eye_right[0], \
                           self.y: batch_train_data_eye_right[1]})

                    train_loss_history.append(train_batch_loss)
                    train_err_history.append(train_batch_err)

                    train_loss_history_eye_left.append(
                        train_batch_loss_eye_left)
                    train_err_history_eye_left.append(train_batch_err_eye_left)

                    train_loss_history_eye_right.append(
                        train_batch_loss_eye_right)
                    train_err_history_eye_right.append(
                        train_batch_err_eye_right)

                    print('Training on batch: %.1fs' %
                          (timeit.default_timer() - start))

                    # Run validation, plot curves, and checkpoint every 30
                    # training batches.
                    if iter % 30 == 0:
                        ifCheck = True

                    if ifCheck:

                        start = timeit.default_timer()

                        if iterTest + 1 >= MaxTestIters:
                            iterTest = 0

                        # The rotating-chunk indices are disabled: validation
                        # always runs on the first val_chunk_size samples of
                        # the shuffled validation list.
                        # test_start = iterTest * val_chunk_size
                        # test_end = (iterTest + 1) * val_chunk_size
                        test_start = 0
                        test_end = val_chunk_size

                        val_data = load_batch_from_data(mtcnn_h,
                                                        val_names,
                                                        dataset_path,
                                                        val_chunk_size,
                                                        img_ch,
                                                        img_cols,
                                                        img_rows,
                                                        train_start=test_start,
                                                        train_end=test_end)

                        val_data = prepare_data(val_data)

                        print('Loading and preparing val data: %.1fs' %
                              (timeit.default_timer() - start))
                        start = timeit.default_timer()

                        val_loss = 0.
                        val_err = 0.
                        val_loss_eye_left = 0.
                        val_err_eye_left = 0.
                        val_loss_eye_right = 0.
                        val_err_eye_right = 0.

                        i_val_left = 0
                        i_val_right = 0

                        for batch_val_data in next_batch(val_data, batch_size):
                            batch_val_data_eye_left, i_val_left = next_batch_universal(
                                val_data_eye_left, batch_size, i_val_left)
                            batch_val_data_eye_right, i_val_right = next_batch_universal(
                                val_data_eye_right, batch_size, i_val_right)

                            val_batch_loss, val_batch_err = sess.run([self.cost1, self.err1], feed_dict={self.eye_left: batch_val_data[0], \
                                self.eye_right: batch_val_data[1], self.face: batch_val_data[2], \
                                self.face_mask: batch_val_data[3], self.y: batch_val_data[4]})

                            val_batch_loss_eye_left, val_batch_err_eye_left = sess.run([self.cost2, self.err2], \
                                feed_dict={self.eye_left: batch_val_data_eye_left[0], \
                                self.y: batch_val_data_eye_left[1]})

                            val_batch_loss_eye_right, val_batch_err_eye_right = sess.run([self.cost3, self.err3], \
                                feed_dict={self.eye_right: batch_val_data_eye_right[0], \
                                self.y: batch_val_data_eye_right[1]})

                            val_loss += val_batch_loss / val_n_batches
                            val_err += val_batch_err / val_n_batches
                            val_loss_eye_left += val_batch_loss_eye_left / val_n_batches
                            val_err_eye_left += val_batch_err_eye_left / val_n_batches
                            val_loss_eye_right += val_batch_loss_eye_right / val_n_batches
                            val_err_eye_right += val_batch_err_eye_right / val_n_batches

                        print("val_loss: ", val_loss, "val_err: ", val_err)
                        print("val_loss_left: ", val_loss_eye_left,
                              "val_err_left: ", val_err_eye_left)
                        print("val_loss_right: ", val_loss_eye_right,
                              "val_err_right: ", val_err_eye_right)

                        iterTest += 1

                        print('Validating on chunk: %.1fs' %
                              (timeit.default_timer() - start))
                        start = timeit.default_timer()

                        if iter_start:
                            print('batch iters runtime: %.1fs' %
                                  (timeit.default_timer() - iter_start))
                        else:
                            iter_start = timeit.default_timer()

                        print("now: ", now)
                        print("learning rate: ", lr)

                        print(
                            'Epoch %s/%s Iter %s, train loss: %.5f, train error: %.5f, val loss: %.5f, val error: %.5f'
                            % (n_epoch, max_epoch, iter,
                               np.mean(train_loss_history),
                               np.mean(train_err_history),
                               np.mean(val_loss_history),
                               np.mean(val_err_history)))

                        print(
                            'Epoch %s/%s Iter %s, train loss_eye_left: %.5f, train error_eye_left: %.5f, val loss_eye_left: %.5f, val error_eye_left: %.5f'
                            % (n_epoch, max_epoch, iter,
                               np.mean(train_loss_history_eye_left),
                               np.mean(train_err_history_eye_left),
                               np.mean(val_loss_history_eye_left),
                               np.mean(val_err_history_eye_left)))

                        print(
                            'Epoch %s/%s Iter %s, train loss_eye_right: %.5f, train error_eye_right: %.5f, val loss_eye_right: %.5f, val error_eye_right: %.5f'
                            % (n_epoch, max_epoch, iter,
                               np.mean(train_loss_history_eye_right),
                               np.mean(train_err_history_eye_right),
                               np.mean(val_loss_history_eye_right),
                               np.mean(val_err_history_eye_right)))

                        val_loss_history.append(val_loss)
                        val_err_history.append(val_err)

                        val_loss_history_eye_left.append(val_loss_eye_left)
                        val_err_history_eye_left.append(val_err_eye_left)

                        val_loss_history_eye_right.append(val_loss_eye_right)
                        val_err_history_eye_right.append(val_err_eye_right)

                        plot_loss(np.array(train_loss_history),
                                  np.array(train_err_history),
                                  np.array(val_loss_history),
                                  np.array(val_err_history),
                                  start=0,
                                  per=1,
                                  save_file=plot_ckpt + "/cumul_loss_" +
                                  str(n_epoch) + "_" + str(iter) + ".png")
                        plot_loss(np.array(train_loss_history_eye_left),
                                  np.array(train_err_history_eye_left),
                                  np.array(val_loss_history_eye_left),
                                  np.array(val_err_history_eye_left),
                                  start=0,
                                  per=1,
                                  save_file=plot_ckpt + "/cumul_loss_" +
                                  str(n_epoch) + "_" + str(iter) +
                                  "_eye_left.png")
                        plot_loss(np.array(train_loss_history_eye_right),
                                  np.array(train_err_history_eye_right),
                                  np.array(val_loss_history_eye_right),
                                  np.array(val_err_history_eye_right),
                                  start=0,
                                  per=1,
                                  save_file=plot_ckpt + "/cumul_loss_" +
                                  str(n_epoch) + "_" + str(iter) +
                                  "_eye_right.png")

                        save_path = ckpt + "model_" + str(n_epoch) + "_" + str(
                            iter) + "_train_error_history_%s" % (
                                np.mean(train_err_history)
                            ) + "_val_error_history_%s" % (
                                np.mean(val_err_history))

                        save_path = saver.save(sess, save_path)
                        print("args.learning_rate: ", args.learning_rate)
                        print("Model saved in file: %s" % save_path)

                        ifCheck = False

                        print('Saving models and plotting loss: %.1fs' %
                              (timeit.default_timer() - start))

                print('epoch runtime: %.1fs' %
                      (timeit.default_timer() - epoch_start))

            return train_loss_history, train_err_history, val_loss_history, val_err_history