Example #1
# Imports assumed by this snippet; Vocab, Prepare, mtl_model and the FLAGS
# definitions are project-local modules, not part of TensorFlow.
import os
import sys

import numpy as np
import tensorflow as tf


def main_func(_):
    print(FLAGS)
    save_path = FLAGS.train_dir + "tfFile/"
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    print("1. Loading WordVocab data...")
    wordVocab = Vocab()
    wordVocab.fromText_format3(FLAGS.train_dir, FLAGS.wordvec_path)
    sys.stdout.flush()

    prepare = Prepare()
    if FLAGS.hasTfrecords:
        print("2. Has Tfrecords File---Train---")
        total_lines = prepare.processTFrecords_hasDone(savePath=save_path, taskNumber=FLAGS.taskNumber)
    else:
        print("2. Start generating TFrecords File--train...")
        total_lines = prepare.processTFrecords(wordVocab, savePath=save_path, max_len=FLAGS.max_len,
                                               taskNumber=FLAGS.taskNumber)
    print("totalLines_train_0:", total_lines[0])
    print("totalLines_train_1:", total_lines[1])
    sys.stdout.flush()

    test_path = FLAGS.train_dir + FLAGS.test_path
    if FLAGS.hasTfrecords:
        print("3. Has TFrecords File--test...")
        totalLines_test = prepare.processTFrecords_test_hasDone(test_path=test_path, taskNumber=1)
    else:
        print("3. Start generating TFrecords File--test...")
        totalLines_test = prepare.processTFrecords_test(wordVocab,
                                                        savePath=save_path,
                                                        test_path=test_path,
                                                        max_len=FLAGS.max_len,
                                                        taskNumber=1)
    print("totalLines_test:", totalLines_test)
    sys.stdout.flush()

    print("4. Start loading TFrecords File...")
    taskNameList = []
    for i in range(FLAGS.taskNumber):
        string = FLAGS.train_dir + 'tfFile/train-' + str(i) + '.tfrecords'
        taskNameList.append(string)
    print("taskNameList: ", taskNameList)
    sys.stdout.flush()

    ################
    # Ceil of the task-0 / task-1 line ratio; used below to shrink the task-1
    # batch size so both input pipelines finish an epoch at about the same step.
    n = total_lines[0] // total_lines[1] + 1 if \
        total_lines[0] % total_lines[1] != 0 else \
        total_lines[0] // total_lines[1]
    print("n: ", n)
    num_batches_per_epoch_train_0 = total_lines[0] // FLAGS.batch_size + 1 if \
        total_lines[0] % FLAGS.batch_size != 0 else \
        total_lines[0] // FLAGS.batch_size
    print("batch_numbers_train_0:", num_batches_per_epoch_train_0)
    batch_size_1 = FLAGS.batch_size // n  # integer division keeps the batch size an int under Python 3

    num_batches_per_epoch_test = totalLines_test // FLAGS.batch_size + 1 if \
        totalLines_test % FLAGS.batch_size != 0 else \
        totalLines_test // FLAGS.batch_size
    print("batch_numbers_test:", num_batches_per_epoch_test)

    with tf.Graph().as_default():
        all_test = prepare.read_records(
            taskname=save_path + "test-0.tfrecords",
            max_len=FLAGS.max_len,
            epochs=FLAGS.num_epochs,
            batch_size=FLAGS.batch_size)

        all_train_0 = prepare.read_records(
            taskname=taskNameList[0],
            max_len=FLAGS.max_len,
            epochs=FLAGS.num_epochs,
            batch_size=FLAGS.batch_size)

        all_train_1 = prepare.read_records(
            taskname=taskNameList[1],
            max_len=FLAGS.max_len,
            epochs=FLAGS.num_epochs,
            batch_size=batch_size_1)

        print("Loading Model...")
        sys.stdout.flush()

        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = True
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            print("------------train model--------------")
            m_train = mtl_model.MTLModel(max_len=FLAGS.max_len,
                                         filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                                         num_filters=FLAGS.num_filters,
                                         num_hidden=FLAGS.num_hidden,
                                         word_vocab=wordVocab,
                                         l2_reg_lambda=FLAGS.l2_reg_lambda,
                                         learning_rate=FLAGS.learning_rate,
                                         adv=FLAGS.adv)
            m_train.build_train_op()
            print("\n\n")

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=20)
            init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
            sess.run(init_op)
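            # Coordinator + queue-runner threads feed the TFRecord input
            # pipelines created by prepare.read_records above.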
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            has_pre_trained_model = False
            out_dir = os.path.abspath(os.path.join(FLAGS.train_dir, "runs"))

            print(out_dir)
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)
            else:
                print("continue training models")
                ckpt = tf.train.get_checkpoint_state(out_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    print("-------has_pre_trained_model--------")
                    print(ckpt.model_checkpoint_path)
                    has_pre_trained_model = True
                    sys.stdout.flush()

            checkpoint_prefix = os.path.join(out_dir, "model")
            if has_pre_trained_model:
                print("Restoring model from " + ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)
                print("DONE!")
                sys.stdout.flush()

            def dev_whole(num_batches_per_epoch_test):
                # NOTE: the task-0 feeds below (input_*_real_0) are read from
                # the enclosing training loop, so this closure must only be
                # called after at least one task-0 batch has been fetched.
                accuracies = []
                losses = []

                for j in range(num_batches_per_epoch_test):
                    input_y_test, input_left_test, input_centre_test = sess.run(
                        [all_test[0], all_test[1], all_test[2]])
                    loss, accuracy, loss_adv, loss_ce = sess.run(
                        [m_train.tensors[1][1], m_train.tensors[1][0], m_train.tensors[1][2],
                         m_train.tensors[1][3]],
                        feed_dict={
                            m_train.input_task_0: 0,
                            m_train.input_left_0: input_left_real_0,
                            m_train.input_right_0: input_centre_real_0,
                            m_train.input_y_0: input_y_real_0,
                            m_train.dropout_keep_prob: FLAGS.dropout_keep_prob,
                            m_train.input_task_1: 1,
                            m_train.input_left_1: input_left_test,
                            m_train.input_right_1: input_centre_test,
                            m_train.input_y_1: input_y_test,
                        })
                    losses.append(loss_ce)
                    accuracies.append(accuracy)
                # print("specfic_prob: ", prob_test)
                sys.stdout.flush()
                return np.mean(np.array(losses)), np.mean(np.array(accuracies))

            def overfit(dev_accuracy):
                # Early-stop heuristic: report overfitting once dev accuracy
                # has failed to improve for four consecutive evaluations.
                n = len(dev_accuracy)
                if n < 4:
                    return False
                for i in range(n - 4, n):
                    if dev_accuracy[i] > dev_accuracy[i - 1]:
                        return False
                return True

            dev_accuracy = []
            total_train_loss = []

            train_loss_0 = 0
            train_loss_1 = 0
            loss_task_0 = 0
            loss_task_1 = 0
            adv_0 = 0
            adv_1 = 0
            acc_1 = 0
            count = 0
            try:
                while not coord.should_stop():  # runs until the input queues are exhausted
                    for i in range(num_batches_per_epoch_train_0 * FLAGS.num_epochs):  # all batches across epochs
                        input_y_real_0, input_left_real_0, input_centre_real_0 = sess.run([all_train_0[0],
                                                                                           all_train_0[1],
                                                                                           all_train_0[2]])
                        input_y_real_1, input_left_real_1, input_centre_real_1 = sess.run([all_train_1[0],
                                                                                           all_train_1[1],
                                                                                           all_train_1[2]])

                        _, loss_0, accuracy_0, loss_adv_0, loss_ce_0 = sess.run(
                            [m_train.train_ops[0],
                             m_train.tensors[0][1], m_train.tensors[0][0], m_train.tensors[0][2],
                             m_train.tensors[0][3]],
                            feed_dict={
                                m_train.input_task_0: 0,
                                m_train.input_left_0: input_left_real_0,
                                m_train.input_right_0: input_centre_real_0,
                                m_train.input_y_0: input_y_real_0,
                                m_train.dropout_keep_prob: FLAGS.dropout_keep_prob,
                                m_train.input_task_1: 1,
                                m_train.input_left_1: input_left_real_1,
                                m_train.input_right_1: input_centre_real_1,
                                m_train.input_y_1: input_y_real_1,
                            })
                        train_loss_0 += loss_0
                        loss_task_0 += loss_ce_0
                        adv_0 += loss_adv_0

                        _, loss_1, accuracy_1, loss_adv_1, loss_ce_1 = sess.run(
                            [m_train.train_ops[1],
                             m_train.tensors[1][1], m_train.tensors[1][0], m_train.tensors[1][2],
                             m_train.tensors[1][3]],
                            feed_dict={
                                m_train.input_task_0: 0,
                                m_train.input_left_0: input_left_real_0,
                                m_train.input_right_0: input_centre_real_0,
                                m_train.input_y_0: input_y_real_0,
                                m_train.dropout_keep_prob: FLAGS.dropout_keep_prob,
                                m_train.input_task_1: 1,
                                m_train.input_left_1: input_left_real_1,
                                m_train.input_right_1: input_centre_real_1,
                                m_train.input_y_1: input_y_real_1,
                            })
                        train_loss_1 += loss_1
                        loss_task_1 += loss_ce_1
                        adv_1 += loss_adv_1
                        acc_1 += accuracy_1

                        count += 1
                        if count % 500 == 0:
                            print("loss {}, acc {}".format(loss_0, accuracy_0))
                            print("--loss {}, acc {}, loss_adv {}, loss_ce {}--".format(loss_1, accuracy_1, loss_adv_1,
                                                                                        loss_ce_1))
                            sys.stdout.flush()

                        if count % num_batches_per_epoch_train_0 == 0 or \
                                count == num_batches_per_epoch_train_0 * FLAGS.num_epochs:

                            print("train_0: ", count / num_batches_per_epoch_train_0,
                                  " epoch, train_loss_0:", train_loss_0,
                                  "loss_task_0: ", loss_task_0,
                                  "adv_0: ", adv_0)

                            print(
                                "train_1: ", count // num_batches_per_epoch_train_0,
                                " epoch, train_loss_1: ", train_loss_1,
                                "loss_task_1: ", loss_task_1,
                                "adv_1: ", adv_1,
                                "acc_1 : ", acc_1 / num_batches_per_epoch_train_0)

                            total_train_loss.append(loss_task_1)
                            train_loss_0 = 0
                            train_loss_1 = 0
                            loss_task_0 = 0
                            loss_task_1 = 0
                            adv_0 = 0
                            adv_1 = 0
                            acc_1 = 0
                            sys.stdout.flush()

                            print("\n------------------Evaluation:-----------------------")
                            _, accuracy = dev_whole(num_batches_per_epoch_test)
                            dev_accuracy.append(accuracy)
                            print("--------Recently dev accuracy:--------")
                            print(dev_accuracy[-10:])

                            print("--------Recently loss_task_1:------")
                            print(total_train_loss[-10:])
                            if overfit(dev_accuracy):
                                print('-----Overfit!!----')
                                # break alone only exits the batch loop; ask the
                                # coordinator to stop so the outer while ends too
                                coord.request_stop()
                                break
                            print("")
                            sys.stdout.flush()

                            path = saver.save(sess, checkpoint_prefix, global_step=count)

                            print("-------------------Saved model checkpoint to {}--------------------".format(path))
                            sys.stdout.flush()
                            output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def,
                                                                                            output_node_names=[
                                                                                                'task_1/prob'])
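                            # Rewrite ref-type ops (RefSwitch/AssignSub) so that
                            # batch-norm moving averages can be frozen into
                            # constants, a common workaround when freezing TF1
                            # graphs that contain batch normalization.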
                            for node in output_graph_def.node:
                                if node.op == 'RefSwitch':
                                    node.op = 'Switch'
                                    for index in range(len(node.input)):
                                        if 'moving_' in node.input[index]:
                                            node.input[index] = node.input[index] + '/read'
                                elif node.op == 'AssignSub':
                                    node.op = 'Sub'
                                    if 'use_locking' in node.attr:
                                        del node.attr['use_locking']

                            with tf.gfile.GFile(FLAGS.train_dir + "runs/mtlmodel_specfic.pb", "wb") as f:
                                f.write(output_graph_def.SerializeToString())
                            print("%d ops in the final graph.\n" % len(output_graph_def.node))




            except tf.errors.OutOfRangeError:
                print("Done")
            finally:
                print("--------------------------finally---------------------------")
                print("current_step:", count)
                coord.request_stop()
                coord.join(threads)

            sess.close()
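
A minimal, hypothetical entry point for either example, assuming FLAGS is built with tf.app.flags in the same module (the unused `_` argument matches the tf.app.run convention):

if __name__ == "__main__":
    tf.app.run(main=main_func)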
Example #2
# Imports are the same as in Example #1, with the project-local task_model
# module in place of mtl_model.
def main_func(_):
    print(FLAGS)
    save_path = FLAGS.train_dir + "tfFile/"
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    print("1. Loading WordVocab data...")
    wordVocab = Vocab()
    wordVocab.fromText_format3(FLAGS.train_dir, FLAGS.wordvec_path)
    sys.stdout.flush()

    prepare = Prepare()
    if FLAGS.hasTfrecords:
        print("2. Has Tfrecords File---Train---")
        total_lines = prepare.processTFrecords_hasDone(
            savePath=save_path, taskNumber=FLAGS.taskNumber)
    else:
        print("2. Start generating TFrecords File--train...")
        total_lines = prepare.processTFrecords(wordVocab,
                                               savePath=save_path,
                                               max_len=FLAGS.max_len,
                                               taskNumber=FLAGS.taskNumber)
    print("totalLines_train_0:", total_lines[0])
    print("totalLines_train_1:", total_lines[1])
    sys.stdout.flush()

    test_path = FLAGS.train_dir + FLAGS.test_path
    if FLAGS.hasTfrecords:
        print("3. Has TFrecords File--test...")
        totalLines_test = prepare.processTFrecords_test_hasDone(
            test_path=test_path, taskNumber=1)
    else:
        print("3. Start generating TFrecords File--test...")
        totalLines_test = prepare.processTFrecords_test(wordVocab,
                                                        savePath=save_path,
                                                        test_path=test_path,
                                                        max_len=FLAGS.max_len,
                                                        taskNumber=1)
    print("totalLines_test:", totalLines_test)
    sys.stdout.flush()

    print("4. Start loading TFrecords File...")
    taskNameList = []
    for i in range(FLAGS.taskNumber):
        string = FLAGS.train_dir + 'tfFile/train-' + str(i) + '.tfrecords'
        taskNameList.append(string)
    print("taskNameList: ", taskNameList)
    sys.stdout.flush()

    ################
    n = total_lines[0] // total_lines[1] + 1 if \
        total_lines[0] % total_lines[1] != 0 else \
        total_lines[0] // total_lines[1]
    print("n: ", n)
    num_batches_per_epoch_train_0 = total_lines[0] // FLAGS.batch_size + 1 if \
        total_lines[0] % FLAGS.batch_size != 0 else \
        total_lines[0] // FLAGS.batch_size
    print("batch_numbers_train_0:", num_batches_per_epoch_train_0)
    batch_size_1 = FLAGS.batch_size // n  # keep the reduced batch size an int

    num_batches_per_epoch_test = totalLines_test // FLAGS.batch_size + 1 if \
        totalLines_test % FLAGS.batch_size != 0 else \
        totalLines_test // FLAGS.batch_size
    print("batch_numbers_test:", num_batches_per_epoch_test)

    with tf.Graph().as_default():
        all_test = prepare.read_records(taskname=save_path +
                                        "test-0.tfrecords",
                                        max_len=FLAGS.max_len,
                                        epochs=FLAGS.num_epochs,
                                        batch_size=FLAGS.batch_size)

        all_train_0 = prepare.read_records(taskname=taskNameList[0],
                                           max_len=FLAGS.max_len,
                                           epochs=FLAGS.num_epochs,
                                           batch_size=FLAGS.batch_size)

        all_train_1 = prepare.read_records(taskname=taskNameList[1],
                                           max_len=FLAGS.max_len,
                                           epochs=FLAGS.num_epochs,
                                           batch_size=batch_size_1)

        print("Loading Model...")
        sys.stdout.flush()

        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = True
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            print("------------train model--------------")
            print("---base model---")
            with tf.variable_scope(name_or_scope='base', reuse=None):
                base_model = task_model.Base_model(
                    max_len=FLAGS.max_len,
                    vocab_size=len(wordVocab.word2id),
                    embedding_size=FLAGS.embedding_dim,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=FLAGS.num_filters,
                    num_hidden=FLAGS.num_hidden,
                    fix_word_vec=FLAGS.fix_word_vec,
                    word_vocab=wordVocab,
                    l2_reg_lambda=FLAGS.l2_reg_lambda,
                    adv=FLAGS.adv,
                    diff=FLAGS.diff,
                    sharedTag=FLAGS.sharedTag)
                if FLAGS.sharedTag:
                    print("---with shared layer---")
                    base_model.func_shared()
                    base_model.func_adv()
                else:
                    print("\n---without adv---")

            print("\n\n---model_0---")
            with tf.variable_scope(name_or_scope='mtl_0', reuse=None):
                mtlmodel_0 = task_model.MTLModel_0(objects=base_model)

            print("\n\n---model_1---")
            with tf.variable_scope(name_or_scope='mtl_1', reuse=None):
                mtlmodel_1 = task_model.MTLModel_1(objects=base_model)
            print("\n\n")

            global_step_0 = tf.Variable(0,
                                        name="global_step_0",
                                        trainable=False)
            optimizer_0 = tf.train.AdamOptimizer(FLAGS.learning_rate)
            update_ops_0 = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops_0):  # for batch_norm
                train_op_0 = optimizer_0.minimize(
                    mtlmodel_0.total_loss + 0.05 * base_model.loss_adv +
                    base_model.l2_reg_lambda * base_model.l2_loss_adv,
                    global_step=global_step_0)

            global_step_1 = tf.Variable(0,
                                        name="global_step_1",
                                        trainable=False)
            optimizer_1 = tf.train.AdamOptimizer(FLAGS.learning_rate)
            update_ops_1 = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops_1):  # for batch_norm
                train_op_1 = optimizer_1.minimize(
                    mtlmodel_1.total_loss + 0.05 * base_model.loss_adv +
                    base_model.l2_reg_lambda * base_model.l2_loss_adv,
                    global_step=global_step_1)
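            # Both train ops add base_model's shared adversarial loss (weighted
            # 0.05) and its L2 penalty to the task-specific total_loss, so the
            # shared layers receive gradients from every task update.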

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=20)
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            sess.run(init_op)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            has_pre_trained_model = False
            out_dir = os.path.abspath(os.path.join(FLAGS.train_dir, "runs"))

            print(out_dir)
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)
            else:
                print("continue training models")
                ckpt = tf.train.get_checkpoint_state(out_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    print("-------has_pre_trained_model--------")
                    print(ckpt.model_checkpoint_path)
                    has_pre_trained_model = True
                    sys.stdout.flush()

            checkpoint_prefix = os.path.join(out_dir, "model")
            if has_pre_trained_model:
                print("Restoring model from " + ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)
                print("DONE!")
                sys.stdout.flush()

            def dev_whole(num_batches_per_epoch_test):
                accuracies = []
                losses = []

                for j in range(num_batches_per_epoch_test):
                    input_y_test_0, input_left_test_0, input_centre_test_0 = sess.run(
                        [all_test[0], all_test[1], all_test[2]])
                    loss_test, accuracy_test, prob_test = sess.run(
                        [
                            mtlmodel_1.total_loss, mtlmodel_1.acc,
                            mtlmodel_1.specfic_prob
                        ],
                        feed_dict={
                            base_model.input_task: 1,
                            base_model.input_left: input_left_test_0,
                            base_model.input_right: input_centre_test_0,
                            base_model.dropout_keep_prob: 1.0,
                            base_model.input_y: input_y_test_0
                        })
                    losses.append(loss_test)
                    accuracies.append(accuracy_test)
                print("specfic_prob: ", prob_test)
                sys.stdout.flush()
                return np.mean(np.array(losses)), np.mean(np.array(accuracies))

            def overfit(dev_accuracy):
                n = len(dev_accuracy)
                if n < 4:
                    return False
                for i in range(n - 4, n):
                    if dev_accuracy[i] > dev_accuracy[i - 1]:
                        return False
                return True

            dev_accuracy = []
            total_train_loss = []

            train_acc_0 = 0
            train_acc_1 = 0
            task_1_loss = 0
            task_0_loss = 0
            loss_adv_0 = 0
            loss_adv_1 = 0
            current_step_0 = 0  # referenced in `finally` even if no batch runs
            try:
                while not coord.should_stop():  # runs until the input queues are exhausted
                    for i in range(num_batches_per_epoch_train_0 *
                                   FLAGS.num_epochs):  # all batches across epochs
                        input_y_real_0, input_left_real_0, input_centre_real_0 = sess.run(
                            [all_train_0[0], all_train_0[1], all_train_0[2]])

                        _, current_step_0, loss_0, accuracy_0, general_loss_adv, general_loss = sess.run(
                            [
                                train_op_0, global_step_0,
                                mtlmodel_0.total_loss, mtlmodel_0.acc,
                                base_model.loss_adv, mtlmodel_0.general_loss
                            ],
                            feed_dict={
                                base_model.input_task: 0,
                                base_model.input_left: input_left_real_0,
                                base_model.input_right: input_centre_real_0,
                                base_model.input_y: input_y_real_0,
                                base_model.dropout_keep_prob: FLAGS.dropout_keep_prob
                            })
                        train_acc_0 += accuracy_0
                        task_0_loss += general_loss
                        loss_adv_0 += general_loss_adv

                        input_y_real_1, input_left_real_1, input_centre_real_1 = sess.run(
                            [all_train_1[0], all_train_1[1], all_train_1[2]])
                        _, current_step_1, loss_1, accuracy_1, specfic_loss_adv, specfic_loss_diff, specfic_loss = sess.run(
                            [
                                train_op_1, global_step_1,
                                mtlmodel_1.total_loss, mtlmodel_1.acc,
                                base_model.loss_adv, mtlmodel_1.loss_diff,
                                mtlmodel_1.specfic_loss
                            ],
                            feed_dict={
                                base_model.input_task: 1,
                                base_model.input_left: input_left_real_1,
                                base_model.input_right: input_centre_real_1,
                                base_model.input_y: input_y_real_1,
                                base_model.dropout_keep_prob: FLAGS.dropout_keep_prob
                            })
                        train_acc_1 += accuracy_1
                        task_1_loss += specfic_loss
                        loss_adv_1 += specfic_loss_adv

                        if current_step_1 % 500 == 0:
                            print("step {}, loss {}, acc {}".format(
                                current_step_0, loss_0, accuracy_0))
                            print(
                                "----------specfic step {}, loss {}, acc {}------------"
                                .format(current_step_1, loss_1, accuracy_1))
                            sys.stdout.flush()

                        if current_step_1 % num_batches_per_epoch_train_0 == 0 or \
                                current_step_1 == num_batches_per_epoch_train_0 * FLAGS.num_epochs:
                            train_acc_0 /= num_batches_per_epoch_train_0
                            train_acc_1 /= num_batches_per_epoch_train_0
                            print(
                                "train_0: ",
                                current_step_0 // num_batches_per_epoch_train_0,
                                " epoch, task_0_loss: ", task_0_loss, "acc: ",
                                train_acc_0, "loss_adv_0: ", loss_adv_0)

                            print(
                                "train_1: ",
                                current_step_1 // num_batches_per_epoch_train_0,
                                " epoch, task_1_loss: ", task_1_loss, "acc: ",
                                train_acc_1, "loss_adv_1: ", loss_adv_1)

                            total_train_loss.append(task_1_loss)
                            train_acc_0 = 0
                            train_acc_1 = 0
                            task_1_loss = 0
                            task_0_loss = 0
                            loss_adv_0 = 0
                            loss_adv_1 = 0
                            sys.stdout.flush()

                            print("\n------------------Evaluation:-----------------------")
                            _, accuracy = dev_whole(num_batches_per_epoch_test)
                            dev_accuracy.append(accuracy)
                            print("--------Recently dev accuracy:--------")
                            print(dev_accuracy[-10:])

                            print("--------Recently train_loss:------")
                            print(total_train_loss[-10:])
                            if overfit(dev_accuracy):
                                print('-----Overfit!!----')
                                # break alone only exits the batch loop; ask the
                                # coordinator to stop so the outer while ends too
                                coord.request_stop()
                                break
                            print("")
                            sys.stdout.flush()

                            path = saver.save(sess,
                                              checkpoint_prefix,
                                              global_step=current_step_0)

                            print("-------------------Saved model checkpoint to {}--------------------".format(path))
                            sys.stdout.flush()
                            output_graph_def = tf.graph_util.convert_variables_to_constants(
                                sess,
                                sess.graph_def,
                                output_node_names=['mtl_1/MTLModel_1/prob'])
                            for node in output_graph_def.node:
                                if node.op == 'RefSwitch':
                                    node.op = 'Switch'
                                    for index in range(len(node.input)):
                                        if 'moving_' in node.input[index]:
                                            node.input[index] += '/read'
                                elif node.op == 'AssignSub':
                                    node.op = 'Sub'
                                    if 'use_locking' in node.attr:
                                        del node.attr['use_locking']

                            with tf.gfile.GFile(
                                    FLAGS.train_dir +
                                    "runs/mtlmodel_specfic.pb", "wb") as f:
                                f.write(output_graph_def.SerializeToString())
                            print("%d ops in the final graph.\n" %
                                  len(output_graph_def.node))

            except tf.errors.OutOfRangeError:
                print("Done")
            finally:
                print("--------------------------finally---------------------------")
                print("current_step:", current_step_0)
                coord.request_stop()
                coord.join(threads)

            sess.close()
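
Once either example has written runs/mtlmodel_specfic.pb, the frozen graph can be used without the checkpoint files. A minimal loading sketch, assuming TF1 and the output node name passed to convert_variables_to_constants above ('mtl_1/MTLModel_1/prob' for Example #2):

def load_frozen_graph(pb_path, output_name):
    # Parse the serialized GraphDef written by convert_variables_to_constants.
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    # Import into a fresh graph; feeding it later requires the placeholder
    # names defined by base_model, which are not shown in these snippets.
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name="")
        prob = graph.get_tensor_by_name(output_name + ":0")
    return graph, prob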