Example #1
def predict(path, modelpath):
    with tf.Graph().as_default():
        # Placeholders
        x = tf.placeholder(tf.float32, [1, 227, 227, 3])
        dropout_keep_prob = tf.placeholder(tf.float32)

        # Read the image, resize it to the network input size and batch it
        img = cv2.imread(path)
        img = cv2.resize(img, (227, 227))
        img = img.astype(np.float32)
        imgs = [img]

        # Model
        model = AlexNetModel(num_classes=FLAGS.num_classes, dropout_keep_prob=dropout_keep_prob)
        logits = model.inference(x)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # Directly restore (the graph must exactly match the checkpoint)
            saver = tf.train.Saver(tf.global_variables())
            saver.restore(sess, modelpath)
            prediction = sess.run(logits, feed_dict={x: imgs, dropout_keep_prob: 1.})
            max_index = np.argmax(prediction)
            print(max_index)
        return max_index
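
Note that predict feeds raw pixel values into the network, while the script in Example #2 below subtracts a mean color before inference. If the checkpoint was trained on mean-subtracted inputs, the same step belongs here, right after the resize; a minimal sketch reusing the values from Example #2:

# Assumption: the checkpoint was trained on mean-subtracted inputs.
# Mean BGR color copied from Example #2; substitute your dataset's mean.
img -= np.array([132.2766, 139.6506, 146.9702])
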
Example #2
def main(_):
    # Placeholders
    x = tf.placeholder(tf.float32, [1, 227, 227, 3])
    dropout_keep_prob = tf.placeholder(tf.float32)

    # Model
    model = AlexNetModel(num_classes=FLAGS.num_classes, dropout_keep_prob=dropout_keep_prob)
    model.inference(x)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Directly restore (the graph must exactly match the checkpoint)
        saver.restore(sess, FLAGS.ckpt)

        batch_x = np.zeros([1, 227, 227, 3], dtype=np.float32)

        # Read image and resize it
        img = cv2.imread(FLAGS.input_image)
        img = cv2.resize(img, (227, 227))
        img = img.astype(np.float32)

        # Subtract mean color
        img -= np.array([132.2766, 139.6506, 146.9702])

        batch_x[0] = img

        scores = sess.run(model.score, feed_dict={x: batch_x, dropout_keep_prob: 1.})
        print(scores)
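
The raw values in scores are logits, not probabilities. A minimal sketch for turning them into a predicted class plus softmax confidences (assumption: model.score yields one row of logits per image):

import numpy as np

def scores_to_prediction(scores):
    logits = scores[0]
    exps = np.exp(logits - np.max(logits))  # numerically stable softmax
    probs = exps / exps.sum()
    return int(np.argmax(probs)), probs

pred_class, probs = scores_to_prediction(scores)
print(pred_class, probs[pred_class])
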
Example #3
def main(_):
    # Placeholders
    x = tf.placeholder(tf.float32, [FLAGS.batch_size, 227, 227, 3])
    y = tf.placeholder(tf.float32, [None, FLAGS.num_classes])
    dropout_keep_prob = tf.placeholder(tf.float32)

    # Model
    model = AlexNetModel(num_classes=FLAGS.num_classes,
                         dropout_keep_prob=dropout_keep_prob)
    model.inference(x)

    # Accuracy of the model
    correct_pred = tf.equal(tf.argmax(model.score, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    saver = tf.train.Saver()
    test_preprocessor = BatchPreprocessor(dataset_file_path=FLAGS.test_file,
                                          num_classes=FLAGS.num_classes,
                                          output_size=[227, 227])
    test_batches_per_epoch = np.floor(
        len(test_preprocessor.labels) / FLAGS.batch_size).astype(np.int16)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Directly restore (the graph must exactly match the checkpoint)
        saver.restore(sess, FLAGS.ckpt)

        test_acc = 0.
        test_count = 0

        for _ in range(test_batches_per_epoch):
            batch_tx, batch_ty = test_preprocessor.next_batch(FLAGS.batch_size)
            acc = sess.run(accuracy,
                           feed_dict={
                               x: batch_tx,
                               y: batch_ty,
                               dropout_keep_prob: 1.
                           })
            test_acc += acc
            test_count += 1

        test_acc /= test_count
        print("{} Test Accuracy = {:.4f}".format(datetime.datetime.now(),
                                                 test_acc))
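
Because test_batches_per_epoch is rounded down with np.floor, any samples left after the last full batch are never evaluated, so the reported accuracy averages over slightly fewer images than the test set contains. A worked example of the arithmetic (illustrative numbers, not from the source):

import numpy as np

num_labels, batch_size = 1000, 32                # illustrative values
batches = int(np.floor(num_labels / batch_size))
skipped = num_labels - batches * batch_size
print(batches, skipped)                          # 31 batches, 8 images skipped

Casting the count to np.int16 also overflows past 32767 batches; np.int32 is the safer choice for large test sets.
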
Example #4
def main(_):
    # Create training directories
    now = datetime.datetime.now()
    train_dir_name = now.strftime('alexnet_%Y%m%d_%H%M%S')
    train_dir = os.path.join(FLAGS.tensorboard_root_dir, train_dir_name)
    checkpoint_dir = os.path.join(train_dir, 'checkpoint')
    tensorboard_dir = os.path.join(train_dir, 'tensorboard')
    tensorboard_train_dir = os.path.join(tensorboard_dir, 'train')
    tensorboard_val_dir = os.path.join(tensorboard_dir, 'val')

    if not os.path.isdir(FLAGS.tensorboard_root_dir): os.mkdir(FLAGS.tensorboard_root_dir)
    if not os.path.isdir(train_dir): os.mkdir(train_dir)
    if not os.path.isdir(checkpoint_dir): os.mkdir(checkpoint_dir)
    if not os.path.isdir(tensorboard_dir): os.mkdir(tensorboard_dir)
    if not os.path.isdir(tensorboard_train_dir): os.mkdir(tensorboard_train_dir)
    if not os.path.isdir(tensorboard_val_dir): os.mkdir(tensorboard_val_dir)

    # Write flags to txt
    flags_file_path = os.path.join(train_dir, 'flags.txt')
    flags_file = open(flags_file_path, 'w')
    flags_file.write('resume={}\n'.format(FLAGS.resume))
    flags_file.write('ckpt_path={}\n'.format(FLAGS.ckpt_path))
    flags_file.write('learning_rate={}\n'.format(FLAGS.learning_rate))
    flags_file.write('dropout_keep_prob={}\n'.format(FLAGS.dropout_keep_prob))
    flags_file.write('num_epochs={}\n'.format(FLAGS.num_epochs))
    flags_file.write('batch_size={}\n'.format(FLAGS.batch_size))
    flags_file.write('multi_scale={}\n'.format(FLAGS.multi_scale))
    flags_file.write('tensorboard_root_dir={}\n'.format(FLAGS.tensorboard_root_dir))
    flags_file.write('log_step={}\n'.format(FLAGS.log_step))
    flags_file.close()

    # Placeholders
    x = tf.placeholder(tf.float32, [FLAGS.batch_size, 227, 227, 3])
    y = tf.placeholder(tf.float32, [None, FLAGS.num_classes])
    dropout_keep_prob = tf.placeholder(tf.float32)

    # Model
    model = AlexNetModel(num_classes=FLAGS.num_classes, dropout_keep_prob=dropout_keep_prob)
    loss = model.loss(x, y)
    train_op = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(loss)
#    train_op = model.optimize(FLAGS.learning_rate, train_layers)

    # Training accuracy of the model
    correct_pred = tf.equal(tf.argmax(model.score, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # Summaries
    tf.summary.scalar('train_loss', loss)
    tf.summary.scalar('train_accuracy', accuracy)
    merged_summary = tf.summary.merge_all()

    train_writer = tf.summary.FileWriter(tensorboard_train_dir)
    val_writer = tf.summary.FileWriter(tensorboard_val_dir)
    saver = tf.train.Saver()

    # Batch preprocessors
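    # FLAGS.multi_scale is expected to hold two comma-separated sizes,
    # e.g. "227,256"; any other value disables multi-scale preprocessing.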
    multi_scale = FLAGS.multi_scale.split(',')
    if len(multi_scale) == 2:
        multi_scale = [int(multi_scale[0]), int(multi_scale[1])]
    else:
        multi_scale = None

    train_preprocessor = BatchPreprocessor(dataset_file_path=FLAGS.training_file, num_classes=FLAGS.num_classes,
                                           output_size=[227, 227], horizontal_flip=True, shuffle=True, multi_scale=multi_scale)
    val_preprocessor = BatchPreprocessor(dataset_file_path=FLAGS.val_file, num_classes=FLAGS.num_classes, output_size=[227, 227])

    # Get the number of training/validation steps per epoch
    train_batches_per_epoch = np.floor(len(train_preprocessor.labels) / FLAGS.batch_size).astype(np.int16)
    val_batches_per_epoch = np.floor(len(val_preprocessor.labels) / FLAGS.batch_size).astype(np.int16)


    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        train_writer.add_graph(sess.graph)

        # Alternative: load the pretrained weights instead of a checkpoint
        # model.load_original_weights(sess, skip_layers=train_layers)

        # Directly restore (the graph must exactly match the checkpoint)
        if FLAGS.resume:
            saver.restore(sess, FLAGS.ckpt_path)

        print("{} Start training...".format(datetime.datetime.now()))
        print("{} Open Tensorboard at --logdir {}".format(datetime.datetime.now(), tensorboard_dir))

        for epoch in range(FLAGS.num_epochs):
            print("{} Epoch number: {}".format(datetime.datetime.now(), epoch+1))
            step = 1

            # Start training
            while step < train_batches_per_epoch:
                batch_xs, batch_ys = train_preprocessor.next_batch(FLAGS.batch_size)
                sess.run(train_op, feed_dict={x: batch_xs, y: batch_ys, dropout_keep_prob: FLAGS.dropout_keep_prob})

                # Logging
                if step % FLAGS.log_step == 0:
                    s = sess.run(merged_summary, feed_dict={x: batch_xs, y: batch_ys, dropout_keep_prob: 1.})
                    train_writer.add_summary(s, epoch * train_batches_per_epoch + step)

                step += 1

            # Epoch completed, start validation
            print("{} Start validation".format(datetime.datetime.now()))
            test_acc = 0.
            test_count = 0

            for _ in range(val_batches_per_epoch):
                batch_tx, batch_ty = val_preprocessor.next_batch(FLAGS.batch_size)
                acc = sess.run(accuracy, feed_dict={x: batch_tx, y: batch_ty, dropout_keep_prob: 1.})
                test_acc += acc
                test_count += 1

            test_acc /= test_count
            s = tf.Summary(value=[
                tf.Summary.Value(tag="validation_accuracy", simple_value=test_acc)
            ])
            val_writer.add_summary(s, epoch+1)
            print("{} Validation Accuracy = {:.4f}".format(datetime.datetime.now(), test_acc))

            # Reset the dataset pointers
            val_preprocessor.reset_pointer()
            train_preprocessor.reset_pointer()

            print("{} Saving checkpoint of model...".format(datetime.datetime.now()))

            # Save a checkpoint of the model
            checkpoint_path = os.path.join(checkpoint_dir, 'model_epoch'+str(epoch+1)+'.ckpt')
            save_path = saver.save(sess, checkpoint_path)

            print("{} Model checkpoint saved at {}".format(datetime.datetime.now(), checkpoint_path))
Example #5
import tensorflow as tf
import numpy as np
import sys
from model import AlexNetModel

# Edit just these
FILE_PATH = '/home/finetune/training/alexnet_20190220_005707/checkpoint/model_epoch7.ckpt'
NUM_CLASSES = 20
OUTPUT_FILE = 'sc_epoch7.npy'

if __name__ == '__main__':
    x = tf.placeholder(tf.float32, [128, 227, 227, 3])
    model = AlexNetModel(num_classes=NUM_CLASSES)
    model.inference(x)

    saver = tf.train.Saver()
    layers = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'fc8']
    data = {
        'conv1': [],
        'conv2': [],
        'conv3': [],
        'conv4': [],
        'conv5': [],
        'fc8': []
    }

    with tf.Session() as sess:
        saver.restore(sess, FILE_PATH)

        for op_name in layers:
            with tf.variable_scope(op_name, reuse=True):
                # The original listing is truncated here. A plausible
                # completion (assumption: each layer scope defines
                # 'weights' and 'biases' variables, as in common
                # AlexNet ports):
                weights = tf.get_variable('weights')
                biases = tf.get_variable('biases')
                data[op_name].append(sess.run(weights))
                data[op_name].append(sess.run(biases))

        # Write the collected tensors to the .npy file named above
        np.save(OUTPUT_FILE, data)
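
To read the dump back, the dictionary has to be unwrapped from the zero-dimensional object array that np.save produces; recent NumPy versions also require allow_pickle=True for object arrays. A minimal usage sketch:

import numpy as np

data = np.load('sc_epoch7.npy', allow_pickle=True).item()  # the OUTPUT_FILE written above
print(data['conv1'][0].shape)  # e.g. the conv1 weight tensor
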
Example #6
def main(_):
    # Create training directories
    now = datetime.datetime.now()
    train_dir_name = now.strftime('ft_%Y%m%d_%H%M%S')
    train_dir = os.path.join(FLAGS.train_root_dir, train_dir_name)
    checkpoint_dir = os.path.join(train_dir, 'checkpoint')
    tensorboard_dir = os.path.join(train_dir, 'tensorboard')
    tensorboard_train_dir = os.path.join(tensorboard_dir, 'train')
    tensorboard_val_dir = os.path.join(tensorboard_dir, 'val')

    if not os.path.isdir(FLAGS.train_root_dir):
        os.mkdir(FLAGS.train_root_dir)
    if not os.path.isdir(train_dir):
        os.mkdir(train_dir)
    if not os.path.isdir(checkpoint_dir):
        os.mkdir(checkpoint_dir)
    if not os.path.isdir(tensorboard_dir):
        os.mkdir(tensorboard_dir)
    if not os.path.isdir(tensorboard_train_dir):
        os.mkdir(tensorboard_train_dir)
    if not os.path.isdir(tensorboard_val_dir):
        os.mkdir(tensorboard_val_dir)

    # Write flags to txt
    flags_file_path = os.path.join(train_dir, 'flags.txt')
    flags_file = open(flags_file_path, 'w')
    flags_file.write('model name: {}\n'.format(MODEL_NAME))
    flags_file.write('learning_rate={}\n'.format(FLAGS.learning_rate))
    flags_file.write('dropout_keep_prob={}\n'.format(FLAGS.dropout_keep_prob))
    flags_file.write('num_epochs={}\n'.format(FLAGS.num_epochs))
    flags_file.write('batch_size={}\n'.format(FLAGS.batch_size))
    flags_file.write('train_layers={}\n'.format(FLAGS.train_layers))
    flags_file.write('multi_scale={}\n'.format(FLAGS.multi_scale))
    flags_file.write('train_root_dir={}\n'.format(FLAGS.train_root_dir))
    flags_file.write('log_step={}\n'.format(FLAGS.log_step))
    flags_file.close()

    # Placeholders
    x = tf.placeholder(tf.float32, [None, 227, 227, 3], 'x')
    y = tf.placeholder(tf.float32, [None, NUM_CLASSES], 'y')
    decay_learning_rate = tf.placeholder(tf.float32)
    dropout_keep_prob = tf.placeholder(tf.float32)

    # Model
    train_layers = FLAGS.train_layers.split(',')
    model = AlexNetModel(num_classes=NUM_CLASSES, dropout_keep_prob=dropout_keep_prob)
    loss = model.get_loss(x, y)
    train_op = model.optimize(decay_learning_rate, train_layers)

    # Training accuracy of the model
    correct_pred = tf.equal(tf.argmax(model.score, 1), tf.argmax(y, 1))
    correct = tf.reduce_sum(tf.cast(correct_pred, tf.float32))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # Initialize the FileWriter
    train_writer = tf.summary.FileWriter(tensorboard_dir + '/train')
    test_writer = tf.summary.FileWriter(tensorboard_dir + '/val')

    # Summaries
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', accuracy)
    merged = tf.summary.merge_all()

    # Batch preprocessors
    multi_scale = FLAGS.multi_scale.split(',')
    if len(multi_scale) == 2:
        multi_scale = [int(multi_scale[0]), int(multi_scale[1])]
    else:
        multi_scale = None

    train_preprocessor = BatchPreprocessor(
        dataset_file_path=TRAINING_FILE,
        num_classes=NUM_CLASSES,
        output_size=[227, 227],
        horizontal_flip=True,
        shuffle=True,
        multi_scale=multi_scale)
    val_preprocessor = BatchPreprocessor(
        dataset_file_path=VAL_FILE,
        num_classes=NUM_CLASSES,
        output_size=[227, 227],
        multi_scale=multi_scale,
        istraining=False)

    # Initialize a saver to store model checkpoints
    saver = tf.train.Saver()

    # Get the number of training steps per epoch
    train_batches_per_epoch = np.floor(len(train_preprocessor.labels) / FLAGS.batch_size).astype(np.int16)

    with tf.Session() as sess:

        # Initialize all variables
        sess.run(tf.global_variables_initializer())

        # Add the model graph to TensorBoard
        train_writer.add_graph(sess.graph)

        # Load the pretrained weights
        model.load_original_weights(sess, skip_layers=train_layers)

        # Directly restore (the graph must exactly match the checkpoint)
        # saver.restore(sess, "/Users/dgurkaynak/Projects/marvel-training/alexnet64-fc6/model_epoch10.ckpt")

        logger.info("Start training...")
        logger.info("tensorboard --logdir {}".format(tensorboard_dir))
        global_step = 0

        for epoch in range(FLAGS.num_epochs):
            # Reset the dataset pointers
            train_preprocessor.reset_pointer()

            step = 1

            while step < train_batches_per_epoch:
                global_step += 1
                rate = decay(FLAGS.learning_rate, global_step, MAX_STEP)

                batch_xs, batch_ys = train_preprocessor.next_batch(FLAGS.batch_size)
                # Fetch the loss value under a new name so it does not
                # shadow the `loss` tensor defined above
                summary, loss_val, _ = sess.run(
                    [merged, model.loss, train_op],
                    feed_dict={
                        x: batch_xs,
                        decay_learning_rate: rate,
                        y: batch_ys,
                        dropout_keep_prob: 0.5
                    })
                train_writer.add_summary(summary, global_step)

                step += 1

                if global_step % 10 == 0:
                    logger.info("epoch {}, step {}, loss {:.6f}".format(epoch, global_step, loss))
                    test_acc = 0.
                    test_count = 0

                    for _ in range(len(val_preprocessor.labels)):
                        batch_tx, batch_ty = val_preprocessor.next_batch(1)
                        acc = sess.run(correct, feed_dict={x: batch_tx, y: batch_ty, dropout_keep_prob: 1.})
                        test_acc += acc
                        test_count += 1
                    test_acc_ = test_acc / test_count
                    s = tf.Summary(value=[tf.Summary.Value(tag="accuracy", simple_value=test_acc_)])
                    test_writer.add_summary(s, global_step)
                    logger.info("test accuracy: {:.4f}, {}/{}".format(test_acc_, test_acc, test_count))

                    # Reset the dataset pointers
                    val_preprocessor.reset_pointer()

                # Save a checkpoint of the model every 1000 steps
                # (the file name is indexed by global step, not epoch)
                if global_step % 1000 == 0 and global_step > 0:
                    logger.info("saving checkpoint of model")
                    checkpoint_path = os.path.join(checkpoint_dir, 'model_epoch' + str(global_step) + '.ckpt')
                    saver.save(sess, checkpoint_path)
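
The decay helper called at the top of the training loop is not defined in this listing. A common choice in this family of domain-adaptation code is the inverse-decay schedule from the DANN/RevGrad line of work; a sketch under that assumption, not the authors' exact definition:

def decay(start_rate, global_step, max_step, alpha=10.0, beta=0.75):
    # Inverse decay: lr / (1 + alpha * p) ** beta, where p is the
    # training progress in [0, 1].
    p = min(1.0, global_step / float(max_step))
    return start_rate / ((1.0 + alpha * p) ** beta)
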
Example #7
def main(_):
    now = datetime.datetime.now()
    train_dir_name = now.strftime('alexnet_%Y%m%d_%H%M%S')
    train_dir = os.path.join(FLAGS.train_root_dir, train_dir_name)
    checkpoint_dir = os.path.join(train_dir, 'checkpoint')
    tensorboard_dir = os.path.join(train_dir, 'tensorboard')
    tensorboard_train_dir = os.path.join(tensorboard_dir, 'train')
    tensorboard_val_dir = os.path.join(tensorboard_dir, 'val')

    if not os.path.isdir(FLAGS.train_root_dir): os.mkdir(FLAGS.train_root_dir)
    if not os.path.isdir(train_dir): os.mkdir(train_dir)
    if not os.path.isdir(checkpoint_dir): os.mkdir(checkpoint_dir)
    if not os.path.isdir(tensorboard_dir): os.mkdir(tensorboard_dir)
    if not os.path.isdir(tensorboard_train_dir): os.mkdir(tensorboard_train_dir)
    if not os.path.isdir(tensorboard_val_dir): os.mkdir(tensorboard_val_dir)

    # Write flags to txt
    flags_file_path = os.path.join(train_dir, 'flags.txt')
    flags_file = open(flags_file_path, 'w')
    flags_file.write('learning_rate={}\n'.format(FLAGS.learning_rate))
    flags_file.write('dropout_keep_prob={}\n'.format(FLAGS.dropout_keep_prob))
    flags_file.write('num_epochs={}\n'.format(FLAGS.num_epochs))
    flags_file.write('batch_size={}\n'.format(FLAGS.batch_size))
    flags_file.write('train_layers={}\n'.format(FLAGS.train_layers))
    flags_file.write('multi_scale={}\n'.format(FLAGS.multi_scale))
    flags_file.write('train_root_dir={}\n'.format(FLAGS.train_root_dir))
    flags_file.write('log_step={}\n'.format(FLAGS.log_step))
    flags_file.close()
    # Placeholders
    x = tf.placeholder(tf.float32, [None, 227, 227, 3], 'x')
    xt = tf.placeholder(tf.float32, [None, 227, 227, 3], 'xt')
    y = tf.placeholder(tf.float32, [None, NUM_CLASSES], 'y')
    yt = tf.placeholder(tf.float32, [None, NUM_CLASSES], 'yt')
    adlamb = tf.placeholder(tf.float32)
    decay_learning_rate = tf.placeholder(tf.float32)
    dropout_keep_prob = tf.placeholder(tf.float32)

    # Model
    train_layers = FLAGS.train_layers.split(',')
    model = AlexNetModel(num_classes=NUM_CLASSES,
                         dropout_keep_prob=dropout_keep_prob)
    loss = model.loss(x, y)
    # Training accuracy of the model
    correct_pred = tf.equal(tf.argmax(model.score, 1), tf.argmax(y, 1))
    correct = tf.reduce_sum(tf.cast(correct_pred, tf.float32))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    G_loss, D_loss, sc, tc = model.adloss(x, xt, y, adlamb)
    target_correct_pred = tf.equal(tf.argmax(model.score, 1), tf.argmax(yt, 1))
    target_correct = tf.reduce_sum(tf.cast(target_correct_pred, tf.float32))
    target_accuracy = tf.reduce_mean(tf.cast(target_correct_pred, tf.float32))
    train_op = model.optimize(decay_learning_rate, train_layers, adlamb, sc,
                              tc)

    # Testing accuracy of the model
    source_vector = model.fc8
    target_vector = model.vector
    target_pre = tf.argmax(model.score, 1)

    D_op = model.adoptimize(decay_learning_rate, train_layers)
    optimizer = tf.group(train_op, D_op)

    train_writer = tf.summary.FileWriter('./log/tensorboard' + MODEL_NAME)
    train_writer.add_graph(tf.get_default_graph())
    tf.summary.scalar('Testing Accuracy', target_accuracy)
    merged = tf.summary.merge_all()

    print('============================ GLOBAL TRAINABLE VARIABLES ============================')
    print(tf.trainable_variables())
    multi_scale = FLAGS.multi_scale.split(',')
    if len(multi_scale) == 2:
        multi_scale = [int(multi_scale[0]), int(multi_scale[1])]
    else:
        multi_scale = None
    print('==================== MULTI SCALE ===================================================')
    print(multi_scale)
    train_preprocessor = BatchPreprocessor(dataset_file_path=TRAINING_FILE,
                                           num_classes=NUM_CLASSES,
                                           output_size=[227, 227],
                                           horizontal_flip=True,
                                           shuffle=True,
                                           multi_scale=multi_scale)
    Ttrain_preprocessor = BatchPreprocessor(dataset_file_path=VAL_FILE,
                                            num_classes=NUM_CLASSES,
                                            output_size=[227, 227],
                                            horizontal_flip=True,
                                            shuffle=True,
                                            multi_scale=multi_scale)
    val_preprocessor = BatchPreprocessor(dataset_file_path=VAL_FILE,
                                         num_classes=NUM_CLASSES,
                                         output_size=[227, 227],
                                         multi_scale=multi_scale,
                                         istraining=False)
    train_batches_per_epoch = np.floor(
        len(train_preprocessor.labels) / FLAGS.batch_size).astype(np.int16)
    Ttrain_batches_per_epoch = np.floor(
        len(Ttrain_preprocessor.labels) / FLAGS.batch_size).astype(np.int16)
    val_batches_per_epoch = np.floor(
        len(val_preprocessor.labels) / FLAGS.batch_size).astype(np.int16)

    dic_s = {}
    dic_temp = {}
    dic_temp1 = {}
    dic_t = {}

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        train_writer.add_graph(sess.graph)
        model.load_original_weights(sess, skip_layers=train_layers)
        print("{} Start training...".format(datetime.datetime.now()))
        print("{} Open Tensorboard at --logdir {}".format(
            datetime.datetime.now(), tensorboard_dir))
        gs = 0
        gd = 0
        best_acc = 0.0
        flag = 1
        flag1 = 1
        first_s = 50
        for epoch in range(FLAGS.num_epochs):
            step = 1
            while step < train_batches_per_epoch:
                gd += 1
                lamb = adaptation_factor(gd * 1.0 / MAX_STEP)
                rate = decay(FLAGS.learning_rate, gd, MAX_STEP)
                for it in range(1):  # single inner iteration
                    gs += 1
                    if gs % Ttrain_batches_per_epoch == 0:
                        Ttrain_preprocessor.reset_pointer()
                    if gs % train_batches_per_epoch == 0:
                        train_preprocessor.reset_pointer()
                    batch_xs, batch_ys = train_preprocessor.next_batch(
                        FLAGS.batch_size)
                    Tbatch_xs, Tbatch_ys = Ttrain_preprocessor.next_batch(
                        FLAGS.batch_size)
                    summary, _ = sess.run(
                        [merged, optimizer],
                        feed_dict={
                            x: batch_xs,
                            xt: Tbatch_xs,
                            yt: Tbatch_ys,
                            adlamb: lamb,
                            decay_learning_rate: rate,
                            y: batch_ys,
                            dropout_keep_prob: 0.5
                        })
                    train_writer.add_summary(summary, gd)
                    closs, gloss, dloss, gregloss, dregloss, floss, smloss = sess.run(
                        [
                            model.loss, model.G_loss, model.D_loss,
                            model.Gregloss, model.Dregloss, model.F_loss,
                            model.Semanticloss
                        ],
                        feed_dict={
                            x: batch_xs,
                            xt: Tbatch_xs,
                            adlamb: lamb,
                            decay_learning_rate: rate,
                            y: batch_ys,
                            dropout_keep_prob: 0.5
                        })
                step += 1

                if epoch == first_s:
                    source_v = sess.run(source_vector,
                                        feed_dict={
                                            x: batch_xs,
                                            y: batch_ys,
                                            xt: Tbatch_xs,
                                            dropout_keep_prob: 1.
                                        })
                    for i in range(FLAGS.batch_size):
                        dic_temp.setdefault(np.argmax(batch_ys[i]),
                                            []).append(source_v[i])
                if epoch == first_s + 1 and flag == 1:
                    for i in dic_temp.keys():
                        dic_s[i] = np.mean(dic_temp[i], axis=0)
                    with open('dic_s.txt', 'w') as f:
                        f.write(str(dic_s))
                    flag = 0

#                if gd%50==0:
                if epoch % 5 == 0 and step == train_batches_per_epoch - 1:
                    print('=================== Step {0:<10} ================='.format(gs))
                    print('Epoch {0:<5} Step {1:<5} Closs {2:<10} Gloss {3:<10} Dloss {4:<10} Total_Loss {7:<10} Gregloss {5:<10} Dregloss {6:<10} Semloss {8:<10}'.format(
                        epoch, step, closs, gloss, dloss, gregloss, dregloss,
                        floss, smloss))
                    print('lambda:', lamb)
                    print('rate:', rate)
                    # Epoch completed, start validation
                    print("{} Start validation".format(
                        datetime.datetime.now()))
                    test_acc = 0.
                    test_count = 0
                    fp = open('pre_and_sim.txt', 'w')
                    for _ in range(len(val_preprocessor.labels)):
                        batch_tx, batch_ty = val_preprocessor.next_batch(1)

                        if flag == 0 and flag1 == 1:
                            target_v = sess.run(target_vector,
                                                feed_dict={
                                                    xt: batch_tx,
                                                    dropout_keep_prob: 1.
                                                })
                            sim_list = []
                            for j in range(NUM_CLASSES):
                                #                                print(target_v[0])
                                #                                print('okkk')
                                #                                print(dic_s[j])
                                sim_value = cos_distance(target_v[0], dic_s[j])
                                sim_list.append(sim_value)
                            max_sim = max(sim_list)
                            max_idx = sim_list.index(max_sim)
                            fp.write(str(max_idx) + ' ' + str(max_sim) + '\n')
                            dic_temp1.setdefault(np.argmax(batch_ty[0]),
                                                 []).append(target_v[0])
                        if epoch > first_s and flag1 == 1:
                            for i in dic_temp1.keys():
                                dic_t[i] = np.mean(dic_temp1[i], axis=0)
                            with open('dic_t.txt', 'w') as f:
                                f.write(str(dic_t))
                            flag1 = 0

                        acc = sess.run(correct,
                                       feed_dict={
                                           x: batch_tx,
                                           y: batch_ty,
                                           dropout_keep_prob: 1.
                                       })
                        test_acc += acc
                        test_count += 1
                    fp.close()
                    print(test_acc, test_count)
                    test_acc /= test_count
                    if test_acc > best_acc:
                        best_acc = test_acc
                    print('best acc is: %f' % best_acc)
                    print("{} Validation Accuracy = {:.4f}".format(
                        datetime.datetime.now(), test_acc))
                    # Reset the dataset pointers
                    val_preprocessor.reset_pointer()
                    #train_preprocessor.reset_pointer()
                if gd % 4000 == 0 and gd > 0:
                    saver.save(
                        sess,
                        './log/mstnmodel_' + MODEL_NAME + str(gd) + '.ckpt')
                    print("{} Saving checkpoint of model...".format(
                        datetime.datetime.now()))
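
Two more helpers used above, adaptation_factor and cos_distance, are likewise undefined in this listing. Sketches under the usual assumptions (a DANN-style ramp-up for the adversarial weight, and plain cosine similarity for comparing feature vectors):

import numpy as np

def adaptation_factor(p, gamma=10.0):
    # Ramps the adversarial weight from 0 to 1 as training progresses;
    # p is the training progress in [0, 1]. Assumed schedule.
    return 2.0 / (1.0 + np.exp(-gamma * min(1.0, p))) - 1.0

def cos_distance(a, b):
    # Cosine similarity between two 1-D feature vectors (despite the
    # name, larger means more similar, matching how max_sim is used).
    a, b = np.asarray(a), np.asarray(b)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
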