def val():
    """Evaluate the latest VGG16 checkpoint on the ``is_train=False`` split.

    Builds the evaluation graph, restores the most recent checkpoint recorded
    in ``train_log_dir`` and prints running and final accuracy figures.

    Returns:
        None. Returns early (after printing a message) when no checkpoint
        is found in ``train_log_dir``.
    """
    num_test = 10000
    data_dir = '/home/rong/something_for_deep/cifar-10-batches-bin'
    train_log_dir = './logs/train'

    image_batch, label_batch = input_data.read_cifar10(data_dir,
                                                       is_train=False,
                                                       batch_size=BATCH_SIZE,
                                                       shuffle=False)
    vgg16 = model.VGG16()
    logits = vgg16.build(image_batch, NUM_CLASSES, False)
    saver = tf.train.Saver()
    correct_per_batch = tools.num_correct_prediction(logits, label_batch)

    with tf.Session() as sess:
        print('Reading checkpoints')
        ckpt = tf.train.get_checkpoint_state(train_log_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # The step number is the suffix after the last '-' of the file name.
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split(
                '-')[-1]
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Loading success, global_step is %s' % global_step)
        else:
            print('No checkpoint file found')
            return

        # The checkpoint restored above is the one that gets evaluated; a
        # second restore of a hard-coded file used to overwrite it here, which
        # made the reported global_step wrong and failed when that file was
        # absent.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            print('\nEvaluating......')
            # Only whole batches are evaluated; a trailing partial batch is
            # dropped.
            num_step = int(math.floor(num_test / BATCH_SIZE))
            num_sample = num_step * BATCH_SIZE
            step = 0
            total_correct = 0
            while step < num_step and not coord.should_stop():
                batch_correct = sess.run(correct_per_batch)
                total_correct += np.sum(batch_correct)
                step += 1
                if step % 10 == 0:
                    seen = step * BATCH_SIZE
                    print('Testing samples: %d' % seen)
                    print('Correct predictions: %d' % total_correct)
                    # Scale to a percentage to match the '%%' in the format.
                    print('Average accuracy: %.2f%%' %
                          (100.0 * total_correct / seen))
            print('Total testing samples: %d' % num_sample)
            print('Total correct predictions: %d' % total_correct)
            print('Average accuracy: %.2f%%' %
                  (100.0 * total_correct / num_sample))
        except Exception as e:
            # Hand the error to the coordinator; join() below re-raises it.
            coord.request_stop(e)
        finally:
            coord.request_stop()
            coord.join(threads)
Esempio n. 2
0
def tower_loss(scope):
    """Build one training tower and return its summed loss tensor.

    Args:
        scope: Name-scope filter passed to ``tf.get_collection`` when
            gathering the entries of the ``'losses'`` collection.

    Returns:
        The ``tf.add_n`` sum of the collected losses, named ``total_loss``.
    """
    image_batch, label_batch = input_data.read_cifar10(
        FLAGS.data_dir, True, FLAGS.batch_size, True)
    tower_logits = resnet.inference(image_batch, FLAGS.num_units_per_block,
                                    FLAGS.is_training)

    # The return value is not needed here; presumably resnet.loss registers
    # its result in the 'losses' collection read below -- TODO confirm.
    resnet.loss(tower_logits, label_batch)

    total_loss = tf.add_n(tf.get_collection('losses', scope),
                          name='total_loss')

    # name_scope(None) resets to the root scope for the summary op.
    with tf.name_scope(None):
        tf.summary.scalar("total_loss", total_loss)
    return total_loss
def evaluate():
    """Restore the latest VGG16N checkpoint and print its accuracy.

    Reads batches with ``is_train=False`` and evaluates whole batches only.
    Returns ``None``; returns early when no checkpoint is found.
    """
    with tf.Graph().as_default():
        checkpoint_dir = 'C:/Users/YuanGuoliang/Desktop/3_5th/VGG_model1/logs/train/'
        cifar_dir = './/data//cifar-10-batches-bin//'
        n_test = 10000

        images, labels = input_data.read_cifar10(
            data_dir=cifar_dir,
            is_train=False,
            batch_size=BATCH_SIZE,
            shuffle=False)

        logits = VGG.VGG16N(images, N_CLASSES, IS_PRETRAIN)
        correct = tools.num_correct_prediction(logits, labels)
        saver = tf.train.Saver(tf.global_variables())

        with tf.Session() as sess:
            print("Reading checkpoints...")
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found')
                return

            # Step number is the suffix after the last '-' in the file name.
            file_name = ckpt.model_checkpoint_path.split('/')[-1]
            global_step = file_name.split('-')[-1]
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Loading success, global_step is %s' % global_step)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                print('\nEvaluating......')
                batches = int(math.floor(n_test / BATCH_SIZE))
                evaluated = batches * BATCH_SIZE
                hits = 0
                for _ in range(batches):
                    if coord.should_stop():
                        break
                    hits += np.sum(sess.run(correct))
                print('Total testing samples: %d' % evaluated)
                print('Total correct predictions: %d' % hits)
                print('Average accuracy: %.2f%%' % (100 * hits / evaluated))
            except Exception as e:
                # Reported to the coordinator; join() re-raises it.
                coord.request_stop(e)
            finally:
                coord.request_stop()
                coord.join(threads)
Esempio n. 4
0
def evaluate():
    """Restore the latest MyResNet checkpoint and print its accuracy.

    The number of evaluation steps is rounded *up*, so the reported example
    count can exceed ``n_test`` when ``BATCH_SIZE`` does not divide it.
    Returns ``None``; returns early when no checkpoint is found.
    """
    with tf.Graph().as_default():
        log_dir = './logs2/train/'
        test_dir = '/content/data/'
        n_test = 10000

        image_batch, label_batch = input_data.read_cifar10(
            test_dir, is_train=False, batch_size=BATCH_SIZE, shuffle=False)

        logits = VGG.MyResNet(image_batch, N_CLASSES, IS_PRETRAIN)
        correct = tools.num_correct_prediction(logits, label_batch)
        saver = tf.train.Saver(tf.global_variables())

        with tf.Session() as sess:
            print('Reading checkpoint...')
            ckpt = tf.train.get_checkpoint_state(log_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found')
                return

            # Step number is the suffix after the last '-' in the file name.
            file_name = ckpt.model_checkpoint_path.split('/')[-1]
            global_step = file_name.split('-')[-1]
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Load success, global step: %s' % global_step)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                print('\nEvaluating...')
                batches = int(math.ceil(n_test / BATCH_SIZE))
                examples = batches * BATCH_SIZE
                hits = 0
                for _ in range(batches):
                    if coord.should_stop():
                        break
                    hits += np.sum(sess.run(correct))

                print("Total test examples: %d" % examples)
                print("Total correct predictions: %d" % hits)
                print("Average accuracy: %.2f%%" % (100 * hits / examples))
            except Exception as e:
                # Reported to the coordinator; join() re-raises it.
                coord.request_stop(e)
            finally:
                coord.request_stop()
                coord.join(threads)
def evaluate():
    """Print the accuracy of the latest VGG16N checkpoint in ``log_dir``.

    Batches are read with ``is_train=False``; only whole batches are counted.
    Returns ``None``; returns early when no checkpoint is found.
    """
    with tf.Graph().as_default():
        log_dir = 'C:/Users/kevin/Documents/tensorflow/VGG/logs/train/'
        test_dir = './/data//cifar-10-batches-bin//'
        n_test = 10000

        images, labels = input_data.read_cifar10(
            data_dir=test_dir,
            is_train=False,
            batch_size=BATCH_SIZE,
            shuffle=False)

        logits = VGG.VGG16N(images, N_CLASSES, IS_PRETRAIN)
        correct = tools.num_correct_prediction(logits, labels)
        saver = tf.train.Saver(tf.global_variables())

        with tf.Session() as sess:
            print("Reading checkpoints...")
            ckpt = tf.train.get_checkpoint_state(log_dir)
            found = ckpt and ckpt.model_checkpoint_path
            if not found:
                print('No checkpoint file found')
                return

            # Step number is the suffix after the last '-' in the file name.
            last_component = ckpt.model_checkpoint_path.split('/')[-1]
            global_step = last_component.split('-')[-1]
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Loading success, global_step is %s' % global_step)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                print('\nEvaluating......')
                n_batches = int(math.floor(n_test / BATCH_SIZE))
                n_samples = n_batches * BATCH_SIZE
                n_correct = 0
                for _ in range(n_batches):
                    if coord.should_stop():
                        break
                    n_correct += np.sum(sess.run(correct))
                print('Total testing samples: %d' % n_samples)
                print('Total correct predictions: %d' % n_correct)
                print('Average accuracy: %.2f%%' %
                      (100 * n_correct / n_samples))
            except Exception as e:
                # Reported to the coordinator; join() re-raises it.
                coord.request_stop(e)
            finally:
                coord.request_stop()
                coord.join(threads)
Esempio n. 6
0
def evaluate():
    """Print the accuracy of the latest VGG16N checkpoint in ``log_dir``.

    Batches are read with ``is_train=False``; only whole batches are counted,
    so a trailing partial batch is ignored.
    Returns ``None``; returns early when no checkpoint is found.
    """
    with tf.Graph().as_default():
        log_dir = './logs/vgg16_logs/train/'
        data_dir = './data/cifar-10-batches-bin/'
        n_test = 10000

        images, labels = input_data.read_cifar10(
            data_dir=data_dir,
            is_train=False,
            batch_size=BATCH_SIZE,
            shuffle=False)

        # Presumably logits has shape [BATCH_SIZE, N_CLASSES] -- TODO confirm.
        logits = VGG.VGG16N(images, N_CLASSES, IS_PRETRAIN)
        # Its per-batch results are summed below to count correct predictions.
        correct = tools.num_correct_prediction(logits, labels)
        saver = tf.train.Saver(tf.global_variables())

        with tf.Session() as sess:
            print("Reading checkpoints...")
            ckpt = tf.train.get_checkpoint_state(log_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found')
                return

            # Step number is the suffix after the last '-' in the file name.
            tail = ckpt.model_checkpoint_path.split('/')[-1]
            global_step = tail.split('-')[-1]
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Loading success, global_step is %s' % global_step)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                print('\nEvaluating......')
                # floor() keeps whole batches only, e.g. 312 steps / 9984
                # samples if BATCH_SIZE is 32.
                step_count = int(math.floor(n_test / BATCH_SIZE))
                sample_count = step_count * BATCH_SIZE
                correct_count = 0
                for _ in range(step_count):
                    if coord.should_stop():
                        break
                    # Accumulate the correct predictions of this batch.
                    correct_count += np.sum(sess.run(correct))
                print('Total testing samples: %d' % sample_count)
                print('Total correct predictions: %d' % correct_count)
                print('Average accuracy: %.2f%%' %
                      (100 * correct_count / sample_count))
            except Exception as e:
                # Reported to the coordinator; join() re-raises it.
                coord.request_stop(e)
            finally:
                coord.request_stop()
                coord.join(threads)
Esempio n. 7
0
def test():
    """Print the accuracy of the latest VGG16 checkpoint in ``train_log_dir``.

    ``data_dir`` and ``train_log_dir`` are read from the enclosing module.
    Only whole batches are counted. Returns ``None``; returns early when no
    checkpoint is found.
    """
    with tf.Graph().as_default():
        n_test = 10000

        images, labels = input_data.read_cifar10(
            data_dir=data_dir,
            is_train=False,
            batch_size=BATCH_SIZE,
            shuffle=False)

        logits = models.VGG16(images, N_CLASSES, IS_PRETRAIN)
        correct = utils.num_correct_prediction(logits, labels)
        saver = tf.train.Saver(tf.global_variables())

        with tf.Session() as sess:
            print("Reading checkpoints...")
            ckpt = tf.train.get_checkpoint_state(train_log_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found!')
                return

            # Step number is the suffix after the last '-' in the file name.
            base_name = ckpt.model_checkpoint_path.split('/')[-1]
            global_step = base_name.split('-')[-1]
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Loading success, global_step is %s' % global_step)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                print('Testing......')
                total_steps = int(math.floor(n_test / BATCH_SIZE))
                total_samples = total_steps * BATCH_SIZE
                total_hits = 0
                for _ in range(total_steps):
                    if coord.should_stop():
                        break
                    total_hits += np.sum(sess.run(correct))
                print('Total testing samples: %d' % total_samples)
                print('Total correct predictions: %d' % total_hits)
                print('Average accuracy: %.2f%%' %
                      (100 * total_hits / total_samples))
            except Exception as e:
                # Reported to the coordinator; join() re-raises it.
                coord.request_stop(e)
            finally:
                coord.request_stop()
                coord.join(threads)
Esempio n. 8
0
def evaluate(batch_size):
    """Evaluate the latest ResNet-18 checkpoint and print its accuracy.

    Args:
        batch_size: Number of examples per evaluation batch; also determines
            how many whole batches of ``n_test`` examples are evaluated.

    Returns:
        None. Returns early (after printing a message) when no checkpoint is
        found in ``log_dir``.
    """
    with tf.Graph().as_default():
        log_dir = 'F:\\DL\\mxnet_cifar10\\trian_logs'
        test_img_dir = 'F:\\DL\\mxnet_cifar10\\data\\cifar-10-batches-py'
        n_test = 20000
        images, labels = input_data.read_cifar10(data_dir=test_img_dir,
                                                 is_train=False,
                                                 batch_size=batch_size,
                                                 shuffle=False)
        # The second return value is fed with False below to select
        # inference mode.
        logits, is_training = resnet.resnet18(images, n_claases=10)
        correct = num_correct_prediction(logits, labels)
        saver = tf.train.Saver(tf.global_variables())
        with tf.Session() as sess:
            print("Reading checkpoints...")
            ckpt = tf.train.get_checkpoint_state(log_dir)
            if ckpt and ckpt.model_checkpoint_path:
                # Step number is the suffix after the last '-' in the name.
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split(
                    '-')[-1]
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('Loading success, global_step is %s' % global_step)
            else:
                print("No checkpoint file found")
                # Nothing else in this function initialises the variables, so
                # evaluating without a restored checkpoint cannot succeed.
                return
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                # Only whole batches are evaluated.
                num_step = int(math.floor(n_test / batch_size))
                num_sample = num_step * batch_size
                step = 0
                total_correct = 0
                while step < num_step and not coord.should_stop():
                    batch_correct = sess.run(correct,
                                             feed_dict={is_training: False})
                    total_correct += np.sum(batch_correct)
                    step += 1
                print('Total testing samples: %d' % num_sample)
                print('Total correct predictions: %d' % total_correct)
                print('Average accuracy: %.2f%%' %
                      (100 * total_correct / num_sample))
            except Exception as e:
                # Reported to the coordinator; join() re-raises it.
                coord.request_stop(e)
            finally:
                coord.request_stop()
                # Join inside finally so the queue threads are always reaped.
                coord.join(threads)
Esempio n. 9
0
def evaluate():
    """Evaluate ``resnet.inference`` on the ``is_train=False`` batches.

    Variables are first initialised, then overwritten from the latest
    checkpoint in ``FLAGS.train_dir`` when one exists. If none is found the
    evaluation still runs, on the freshly initialised variables.

    Returns:
        None. Results are printed. An exception raised during evaluation is
        printed instead of propagated.
    """
    with tf.Graph().as_default():
        test_images, test_labels = input_data.read_cifar10(
            FLAGS.DATA_DIR, False, FLAGS.BATCH_SIZE, False)
        logits = resnet.inference(test_images, 7, is_training=False)
        init = tf.global_variables_initializer()
        acc_num = num_correct_predition(logits, test_labels)
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(init)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            print('reading checkpoints...')
            checkpoint = tf.train.latest_checkpoint(FLAGS.train_dir)
            if checkpoint:
                saver.restore(sess, checkpoint)
                print("restore from the checkpoint {0}".format(checkpoint))
            else:
                print('No checkpoint file found')
            try:
                print('\nEvaluating...')
                # Only whole batches are evaluated.
                num_step = int(math.floor(FLAGS.num_test / FLAGS.BATCH_SIZE))
                num_sample = num_step * FLAGS.BATCH_SIZE
                step = 0
                total_correct = 0
                print(num_step, num_sample)
                while step < num_step:
                    batch_correct = sess.run(acc_num)
                    total_correct += np.sum(batch_correct)
                    step += 1
                    print(step)
                print('Total testing samples: %d' % num_sample)
                print('Total correct predictions: %d' % total_correct)
                print('Average accuracy: %.2f%%' %
                      (100 * total_correct / num_sample))
            except Exception as e:
                # Still non-fatal, but report the failure instead of silently
                # ending with no results and no explanation.
                print('Evaluation aborted: %r' % (e,))
            finally:
                coord.request_stop()
                coord.join(threads)
def train():
    """Fine-tune VGG16N from pre-trained weights, with periodic validation.

    The network is built on the placeholders ``x`` / ``y_`` so that the same
    loss and accuracy ops can be evaluated on either a training batch or a
    validation batch through ``feed_dict``. Summaries go to
    ``train_log_dir`` / ``val_log_dir`` and checkpoints to ``train_log_dir``.

    Returns:
        None.
    """
    pre_trained_weights = './/vgg16_pretrain//vgg16.npy'
    data_dir = './/data//cifar-10-batches-bin//'
    train_log_dir = './/logs//train//'
    val_log_dir = './/logs//val//'

    with tf.name_scope('input'):
        tra_image_batch, tra_label_batch = input_data.read_cifar10(
            data_dir=data_dir,
            is_train=True,
            batch_size=BATCH_SIZE,
            shuffle=True)
        val_image_batch, val_label_batch = input_data.read_cifar10(
            data_dir=data_dir,
            is_train=False,
            batch_size=BATCH_SIZE,
            shuffle=False)

    x = tf.placeholder(tf.float32, shape=[BATCH_SIZE, IMG_W, IMG_H, 3])
    # Use the dtype the input queue produces, so loss/accuracy receive labels
    # of exactly the type they got when wired to the queue directly.
    y_ = tf.placeholder(tra_label_batch.dtype,
                        shape=[BATCH_SIZE, N_CLASSES])

    # Build the model on the placeholders. Previously it was built on the
    # training queue tensors, which left x / y_ unconnected: every feed_dict
    # was a no-op and the "validation" numbers came from training batches.
    logits = VGG.VGG16N(x, N_CLASSES, IS_PRETRAIN)
    loss = tools.loss(logits, y_)
    accuracy = tools.accuracy(logits, y_)
    my_global_step = tf.Variable(0, name='global_step', trainable=False)
    train_op = tools.optimize(loss, learning_rate, my_global_step)

    saver = tf.train.Saver(tf.global_variables())
    summary_op = tf.summary.merge_all()

    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)

    # load the parameter file, assign the parameters, skip the specific layers
    tools.load_with_skip(pre_trained_weights, sess, ['fc6', 'fc7', 'fc8'])

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    tra_summary_writer = tf.summary.FileWriter(train_log_dir, sess.graph)
    val_summary_writer = tf.summary.FileWriter(val_log_dir, sess.graph)

    try:
        for step in np.arange(MAX_STEP):
            if coord.should_stop():
                break

            tra_images, tra_labels = sess.run(
                [tra_image_batch, tra_label_batch])
            train_feed = {x: tra_images, y_: tra_labels}
            _, tra_loss, tra_acc = sess.run([train_op, loss, accuracy],
                                            feed_dict=train_feed)
            if step % 50 == 0 or (step + 1) == MAX_STEP:
                print('Step: %d, loss: %.4f, accuracy: %.4f%%' %
                      (step, tra_loss, tra_acc))
                # Summaries depend on the placeholders, so feed the batch.
                summary_str = sess.run(summary_op, feed_dict=train_feed)
                tra_summary_writer.add_summary(summary_str, step)

            if step % 200 == 0 or (step + 1) == MAX_STEP:
                val_images, val_labels = sess.run(
                    [val_image_batch, val_label_batch])
                val_feed = {x: val_images, y_: val_labels}
                val_loss, val_acc = sess.run([loss, accuracy],
                                             feed_dict=val_feed)
                print('**  Step %d, val loss = %.2f, val accuracy = %.2f%%  **'
                      % (step, val_loss, val_acc))

                summary_str = sess.run(summary_op, feed_dict=val_feed)
                val_summary_writer.add_summary(summary_str, step)

            if step % 2000 == 0 or (step + 1) == MAX_STEP:
                checkpoint_path = os.path.join(train_log_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        # Always stop the queue threads and release the session, including
        # when an unexpected exception is propagating.
        coord.request_stop()
        coord.join(threads)
        sess.close()
Esempio n. 11
0
def train():
    """Train a DeViSE-style projection on top of VGG16N, then run a test pass.

    Phase 1 (first session): builds VGG16N on the placeholder ``x``, restores
    the VGG variables from a checkpoint, then minimises a hinge loss over two
    new dense layers stacked on the ``VGG16/fc7/Relu:0`` tensor. Checkpoints
    of all global variables are written to ``train_log_dir``.

    Phase 2 (second session): restores the latest checkpoint from
    ``train_log_dir`` and runs ``MAX_STEP`` batches from the
    ``is_train=False`` queue, printing the last fetched accuracy value.

    Besides the names used by the graph (``BATCH_SIZE``, ``IMG_W``, ``IMG_H``,
    ``N_CLASSES``, ``MAX_STEP``), the function reads ``word_vector_dim``,
    ``label_string_vector``, ``tmparray``, ``id_array`` and ``keys``, none of
    which are defined in this function.

    Returns:
        None. Returns early when a required checkpoint is not found.
    """
    #with tf.device('/device:GPU:0'):
    with tf.Graph().as_default():

        #        log_dir = 'C://Users//kevin//Documents//tensorflow//VGG//logsvgg//train//'
        log_dir = './vgg/logs/train/'

        #input where you put your binary cifar10 image
        Data_dir = '/home/tony/Desktop/Datasets/cifar_10_batches_binary_version/'

        train_log_dir = './devise_logs/train/'
        val_log_dir = './devise_logs/val/'
        #n_test the number of image data
        # NOTE(review): n_test is assigned but never read in this function.
        n_test = 10000

        with tf.name_scope('input'):
            tra_image_batch, tra_label_batch = input_data.read_cifar10(
                data_dir=Data_dir,
                is_train=True,
                batch_size=BATCH_SIZE,
                shuffle=True)
            val_image_batch, val_label_batch = input_data.read_cifar10(
                data_dir=Data_dir,
                is_train=False,
                batch_size=BATCH_SIZE,
                shuffle=False)

        #images
        x = tf.placeholder(tf.float32,
                           shape=[BATCH_SIZE, IMG_W, IMG_H, 3],
                           name='input_x')

        #labels
        y_ = tf.placeholder(tf.int16,
                            shape=[BATCH_SIZE, N_CLASSES],
                            name='input_y_')

        logits = VGG.VGG16N(x, N_CLASSES, is_pretrain=False)
        # Created before the extra layers and the optimizer exist, so this
        # saver only covers the global variables defined up to this point.
        saver = tf.train.Saver(tf.global_variables())

        id_array_tensor = tf.placeholder(tf.int32, [N_CLASSES])

        softmax_accuracy = tools.accuracy(logits, y_)
        vgg_predict = tf.argmax(logits, -1)
        vgg_predict_id = tf.gather_nd(id_array_tensor, vgg_predict)

        get_label_string_tensor = tf.placeholder(
            tf.float32, shape=[None, word_vector_dim],
            name="label_string")  ##word_vector_dim

        # Multiplies the labels by the fed label matrix; presumably this
        # selects the word vector of the true class -- TODO confirm.
        tmp = tf.matmul(tf.cast(y_, tf.float32), get_label_string_tensor)
        #print(tmp.shape)

        initializer = tf.contrib.layers.variance_scaling_initializer()
        fc7 = tf.get_default_graph().get_tensor_by_name("VGG16/fc7/Relu:0")

        # Two new dense layers mapping fc7 to a word_vector_dim-wide output.
        fc8 = tf.layers.dense(inputs=fc7,
                              units=1024,
                              kernel_initializer=initializer,
                              name="combination_hidden1")
        image_feature_output = tf.layers.dense(inputs=fc8,
                                               units=word_vector_dim,
                                               kernel_initializer=initializer,
                                               name="combination_hidden2")

        #devise_loss

        tmparray_tensor = tf.placeholder(tf.float32,
                                         shape=[None, word_vector_dim],
                                         name='tmparray')

        margin = 0.1
        # tMV corresponds to max(margin + tJ*M*V - tLabel*M*V, 0) in the paper;
        # tmparray_tensor plays the role of tJ and tmp that of tLabel.
        tMV = tf.nn.relu(margin + tf.matmul(
            (tmparray_tensor -
             tmp), tf.transpose(tf.cast(image_feature_output, tf.float32))))
        hinge_loss = tf.reduce_mean(tf.reduce_sum(tMV, 0), name='hinge_loss')

        train_step1 = tf.train.AdamOptimizer(
            0.0001, name="optimizer").minimize(hinge_loss)

        # tMV_ corresponds to tJ*M*V in the paper (no margin, no relu).
        tMV_ = tf.matmul(
            tmparray_tensor,
            tf.transpose(tf.cast(image_feature_output, tf.float32)))

        #accuracy
        # Index of the highest-scoring row of tmparray_tensor per column.
        predict_label = tf.argmax(tMV_, 0)
        predict_label = tf.cast(predict_label, tf.int32)
        predict_label = tf.reshape(predict_label, [-1, 1],
                                   name='predict_label_text')

        #id_array_tensor = tf.placeholder(tf.int32 ,[N_CLASSES])
        select_id = tf.cast(tf.argmax(input=y_, axis=-1), tf.int32)
        # NOTE(review): reshaping to [1] only succeeds for a single element,
        # so this appears to require BATCH_SIZE == 1 -- confirm.
        select_id = tf.reshape(select_id, [1])
        y_label = tf.gather_nd(id_array_tensor, select_id)
        y_label = tf.reshape(y_label, [-1, 1], name='true_label_text')

        #y_label = tf.argmax(tf.matmul(tmparray_tensor , tf.transpose(tmp)), 0) #(2000,word_vector_dim)*(word_vector_dim,1)
        #y_label = tf.reshape(y_label,[-1,1])

        print(y_label.shape)
        print(predict_label.shape)

        # tf.metrics.accuracy keeps its running totals in local variables,
        # hence the local_variables_initializer() calls in both sessions.
        acc, acc_op = tf.metrics.accuracy(labels=y_label,
                                          predictions=predict_label,
                                          weights=None,
                                          metrics_collections=None,
                                          updates_collections=None,
                                          name="acc")

        summary_op = tf.summary.merge_all()

        # Second saver, created after all layers and the optimizer, used for
        # the checkpoints written to and read from train_log_dir.
        saver2 = tf.train.Saver(tf.global_variables())

        with tf.Session() as sess:
            tra_summary_writer = tf.summary.FileWriter(train_log_dir,
                                                       sess.graph)
            val_summary_writer = tf.summary.FileWriter(val_log_dir, sess.graph)

            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())

            print("Reading checkpoints...")
            ckpt = tf.train.get_checkpoint_state(log_dir)
            if ckpt and ckpt.model_checkpoint_path:
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split(
                    '-')[-1]
                #saver.restore(sess, ckpt.model_checkpoint_path)
                # NOTE(review): the restored file is a hard-coded absolute
                # path, not ckpt.model_checkpoint_path, so the global_step
                # printed below may not describe the file actually loaded.
                saver.restore(
                    sess,
                    "/home/tony/Desktop/DeVise/vgg/logs/train/model.ckpt-14999"
                )
                print('Loading success, global_step is %s' % global_step)
            else:
                print('No checkpoint file found')
                return
##---------------------------------------------------------------Training-------------------------------------------------------------------------------------------------
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            # NOTE(review): '\T' is not a recognised escape sequence, so the
            # backslash is printed literally; '\n' was probably intended, and
            # "Triaining" looks like a typo for "Training".
            print('\Triaining....')
            try:
                vgg_total_correct = 0
                for step in tqdm(range(MAX_STEP)):
                    if coord.should_stop():
                        break
                    #print("step %d"%step)
                    # Pull one batch from the training queue, then feed it
                    # through the placeholders.
                    tra_images, tra_labels = sess.run(
                        [tra_image_batch, tra_label_batch])

                    loss, _, accuracy_, acc_operator, predict_label_, y_label_, summary_str, vgg_correct, vgg_predict_id_ = sess.run(
                        [
                            hinge_loss, train_step1, acc, acc_op,
                            predict_label, y_label, summary_op,
                            softmax_accuracy, vgg_predict_id
                        ],
                        feed_dict={
                            get_label_string_tensor: label_string_vector,
                            tmparray_tensor: tmparray,
                            id_array_tensor: id_array,
                            x: tra_images,
                            y_: tra_labels
                        })

                    #print(vgg_correct)
                    # Presumably tools.accuracy returns a percentage, so > 50
                    # marks a correctly classified batch -- TODO confirm.
                    if vgg_correct > 50:
                        vgg_total_correct = vgg_total_correct + 1

                    if step % 100 == 0:
                        print("step %d" % step)
                        print('%d / %d steps' % (step, MAX_STEP), 'loss = ',
                              loss, '    acc = ', accuracy_, '\n\n')
                        print('vgg predict acc  ',
                              vgg_total_correct * 1.0 / (step + 1),
                              ' ---->vgg predict     ',
                              keys[int(vgg_predict_id_)])
                        print('           devise predict_label',
                              predict_label_, ' ---->DeVise predict  ',
                              keys[int(predict_label_)])
                        print('                y_label             ', y_label_,
                              ' ---->ground  truth is', keys[int(y_label_)],
                              '\n\n-------\n\n')

                    if step % 2000 == 0 or (step + 1) == MAX_STEP:
                        checkpoint_path = os.path.join(train_log_dir,
                                                       'model.ckpt')
                        saver2.save(sess, checkpoint_path, global_step=step)

            except tf.errors.OutOfRangeError:
                print('Done training -- epoch limit reached')
            finally:
                coord.request_stop()
                coord.join(threads)

        print("end training\n\n\n------------------------------------------\n")
        ##-----------------------------------------------------------------Testing---------------------------------------
        with tf.Session() as sess:
            print('----Testing----')
            #input num of data

            print("Reading checkpoints...")
            # Reload what the training phase just saved with saver2.
            ckpt = tf.train.get_checkpoint_state(train_log_dir)
            print(ckpt)
            if ckpt and ckpt.model_checkpoint_path:
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split(
                    '-')[-1]
                #saver.restore(sess, ckpt.model_checkpoint_path)
                print(global_step)
                print(ckpt.model_checkpoint_path)
                saver2.restore(sess, ckpt.model_checkpoint_path)
                print('Loading success, global_step is %s' % global_step)
            else:
                print('No checkpoint file found')
                return
            # Initialise the metric's local variables in this new session.
            sess.run(tf.local_variables_initializer())

            # NOTE(review): num_of_test is assigned but never read.
            num_of_test = MAX_STEP
            # Stays at -1 if no step completes.
            get_acc = -1
            try:
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                for step in tqdm(range(MAX_STEP)):
                    if coord.should_stop():
                        break
                    #print("step %d"%step)
                    val_images, val_labels = sess.run(
                        [val_image_batch, val_label_batch])

                    loss, accuracy_, acc_operator, predict_label_, y_label_ = sess.run(
                        [hinge_loss, acc, acc_op, predict_label, y_label],
                        feed_dict={
                            get_label_string_tensor: label_string_vector,
                            tmparray_tensor: tmparray,
                            id_array_tensor: id_array,
                            x: val_images,
                            y_: val_labels
                        })

                    #print('%d / %d steps'%(step,MAX_STEP),'loss = ',loss,'    acc = ',accuracy_,'\n\n')
                    #print ('predict_label',predict_label_,' ----> I predict ',keys[int(predict_label_)])
                    #print ('y_label      ',y_label_,      ' ----> true ans',keys[int(y_label_)],'\n\n-------\n\n')
                    # Keep the most recently fetched value of the metric.
                    get_acc = accuracy_

            except tf.errors.OutOfRangeError:
                print('Done test -- epoch limit reached')
            finally:
                print('test acc = ', get_acc)
                coord.request_stop()
                coord.join(threads)
Esempio n. 12
0
def train():
    """Train the ResNet model on CIFAR-10, logging summaries and checkpoints.

    Builds the train/test input queues, a placeholder-fed ResNet graph and a
    momentum optimizer whose learning rate decays exponentially, then runs
    the training loop up to ``FLAGS.max_steps``. If a checkpoint exists in
    ``FLAGS.train_dir`` it is restored and the loop resumes from the step
    encoded in the checkpoint file name.

    Every 50 steps the training loss/accuracy are printed, every 200 steps
    the test loss/accuracy, every 300 steps summaries are written for both
    splits, and every 2000 steps a checkpoint is saved. The final step
    triggers all four.
    """
    with tf.name_scope('inputs'):
        train_images, train_labels = input_data.read_cifar10(
            FLAGS.data_dir, True, FLAGS.batch_size, True)
        test_images, test_labels = input_data.read_cifar10(
            FLAGS.data_dir, False, FLAGS.batch_size, False)
    xs = tf.placeholder(dtype=tf.float32, shape=(FLAGS.batch_size, 32, 32, 3))
    ys = tf.placeholder(dtype=tf.int32, shape=(FLAGS.batch_size, 10))
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(FLAGS.lr,
                                               global_step,
                                               32000,
                                               0.1,
                                               staircase=False)
    # The summary tag keeps its original spelling so that event files written
    # by earlier runs still line up with new ones.
    tf.summary.scalar('lerning_rate', learning_rate)

    logits = resnet.inference(xs, FLAGS.num_units_per_block, FLAGS.is_training)
    loss = resnet.loss(logits, ys)
    tf.summary.scalar('loss', loss)
    opt = tf.train.MomentumOptimizer(learning_rate, 0.9)
    # Run the collected UPDATE_OPS together with every optimizer step.
    update_op = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_op):
        train_op = opt.minimize(loss, global_step=global_step)

    acc_op = resnet.accurracy(logits, ys)
    tf.summary.scalar('accuracy', acc_op)
    err_op = resnet.error(logits, ys)
    tf.summary.scalar('error', err_op)

    summary_op = tf.summary.merge_all()
    # tf.all_variables() is deprecated in favour of tf.global_variables().
    saver = tf.train.Saver(tf.global_variables())
    init = tf.global_variables_initializer()
    coord = tf.train.Coordinator()

    with tf.Session() as sess:
        sess.run(init)
        train_summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                                     sess.graph)
        test_summary_writer = tf.summary.FileWriter(FLAGS.test_dir, sess.graph)
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        start_step = 0
        checkpoint = tf.train.latest_checkpoint(FLAGS.train_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print("restore from the checkpoint {0}".format(checkpoint))
            # Checkpoints are saved as 'model.ckpt-<step>'.
            start_step += int(checkpoint.split('-')[-1])
        print("start training...")
        try:
            for step in range(start_step, FLAGS.max_steps):
                if coord.should_stop():
                    break
                is_last_step = (step + 1) == FLAGS.max_steps
                report_train = step % 50 == 0 or is_last_step
                report_test = step % 200 == 0 or is_last_step
                write_summaries = step % 300 == 0 or is_last_step
                save_checkpoint = step % 2000 == 0 or is_last_step

                tra_images_batch, tra_labels_batch = sess.run(
                    [train_images, train_labels])
                train_feed = {xs: tra_images_batch, ys: tra_labels_batch}
                sess.run(train_op, feed_dict=train_feed)

                if report_train:
                    tra_los, tra_acc = sess.run([loss, acc_op],
                                                feed_dict=train_feed)
                    print('Step: %d, loss: %.6f, accuracy: %.4f' %
                          (step, tra_los, tra_acc))

                # Only dequeue a test batch on the steps that actually use it
                # instead of pulling (and discarding) one on every step.
                if report_test or write_summaries:
                    tes_images_batch, tes_labels_batch = sess.run(
                        [test_images, test_labels])
                    test_feed = {xs: tes_images_batch, ys: tes_labels_batch}

                if report_test:
                    tes_los, tes_acc = sess.run([loss, acc_op],
                                                feed_dict=test_feed)
                    print(
                        '***test_loss***Step: %d, loss: %.6f, accuracy: %.4f' %
                        (step, tes_los, tes_acc))

                if write_summaries:
                    summary_str1 = sess.run(summary_op, feed_dict=train_feed)
                    summary_str2 = sess.run(summary_op, feed_dict=test_feed)
                    train_summary_writer.add_summary(summary_str1, step)
                    test_summary_writer.add_summary(summary_str2, step)

                if save_checkpoint:
                    checkpoint_path = os.path.join(FLAGS.train_dir,
                                                   'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)

        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            # Flush pending summaries, then stop and join the queue threads.
            train_summary_writer.close()
            test_summary_writer.close()
            coord.request_stop()
            coord.join(threads)
Esempio n. 13
0
def train(lr, batch_size, max_step, n_classes):
    """Train ResNet-18 on CIFAR-10 with Adam, validating and checkpointing.

    Every 50 steps the training loss/accuracy of the current batch are
    printed, every 200 steps one validation batch is evaluated, and every
    2000 steps a checkpoint is written to the hard-coded log directory. The
    final step triggers all three.

    Args:
        lr: Learning rate passed to ``tf.train.AdamOptimizer``.
        batch_size: Batch size used for the input queues and placeholders.
        max_step: Number of training steps to run.
        n_classes: Number of classes passed to ``resnet.resnet18``.
    """
    train_log_dir = 'F:\\DL\\mxnet_cifar10\\trian_logs'
    data_dir = 'F:\\DL\\mxnet_cifar10\\data\\cifar-10-batches-py'
    with tf.name_scope('input'):
        train_img_batch, train_label_batch = input_data.read_cifar10(
            data_dir, is_train=True, batch_size=batch_size, shuffle=True)
        val_img_batch, val_label_batch = input_data.read_cifar10(
            data_dir, is_train=False, batch_size=batch_size, shuffle=False)
    features = tf.placeholder(tf.float32, [batch_size, 32, 32, 3])
    # NOTE(review): the label width is fixed at 10 while the network takes
    # n_classes -- presumably only n_classes == 10 is supported; confirm.
    labels = tf.placeholder(tf.int64, [batch_size, 10])
    logits, is_train = resnet.resnet18(features, n_classes)
    cross_entropy = loss(logits, labels)
    acc = accuracy(logits, labels)
    with tf.name_scope('adam_optimizer'):
        # Run the collected UPDATE_OPS together with every optimizer step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_step = tf.train.AdamOptimizer(lr).minimize(cross_entropy)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            for step in range(max_step):
                if coord.should_stop():
                    break
                is_last_step = (step + 1) == max_step

                tra_img, tra_label = sess.run(
                    [train_img_batch, train_label_batch])
                _, tra_loss, tra_acc = sess.run(
                    [train_step, cross_entropy, acc],
                    feed_dict={
                        features: tra_img,
                        labels: tra_label,
                        is_train: True
                    })

                if step % 50 == 0 or is_last_step:
                    print('Step: %d, loss: %.4f, accuracy: %.4f%%' %
                          (step, tra_loss, tra_acc))

                if step % 200 == 0 or is_last_step:
                    val_img, val_label = sess.run(
                        [val_img_batch, val_label_batch])
                    val_loss, val_acc = sess.run([cross_entropy, acc],
                                                 feed_dict={
                                                     features: val_img,
                                                     labels: val_label,
                                                     is_train: False
                                                 })
                    print(
                        '**  Step %d, val loss = %.2f, val accuracy = %.2f%%  **'
                        % (step, val_loss, val_acc))

                if step % 2000 == 0 or is_last_step:
                    checkpoint_path = os.path.join(train_log_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)
        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            # Join inside ``finally`` so the queue threads are also waited for
            # when an unexpected exception is propagating.
            coord.request_stop()
            coord.join(threads)
Esempio n. 14
0
def load_mnist(dset, n_labeled=-1, data_root=os.path.join('..', 'data')):
    """Load one of the supported datasets through ``input_data``.

    Args:
        dset: Dataset key; one of 'digits', 'fashion', 'fashion_2d' or
            'cifar10'.
        n_labeled: Forwarded as ``n_labeled`` to ``input_data.read_mnist``
            (not used for 'cifar10').
        data_root: Directory that contains one sub-directory per dataset.

    Returns:
        Whatever the selected ``input_data`` reader returns, or ``None`` when
        ``dset`` is not one of the supported keys.
    """
    data_path = os.path.join(data_root, dset)

    if dset == 'cifar10':
        return input_data.read_cifar10(data_path)

    if dset == 'digits':
        return input_data.read_mnist(data_path,
                                     one_hot=True,
                                     SOURCE_URL=input_data.SOURCE_DIGITS,
                                     n_labeled=n_labeled)

    if dset == 'fashion':
        return input_data.read_mnist(data_path,
                                     one_hot=True,
                                     SOURCE_URL=input_data.SOURCE_FASHION,
                                     n_labeled=n_labeled)

    if dset == 'fashion_2d':
        # The 2-D variant reads the plain 'fashion' directory and passes an
        # extra ``binary_zero`` list to the reader.
        fashion_path = os.path.join(data_root, 'fashion')
        return input_data.read_mnist(fashion_path,
                                     one_hot=True,
                                     SOURCE_URL=input_data.SOURCE_FASHION,
                                     n_labeled=n_labeled,
                                     binary_zero=[0, 2, 3, 4, 6])

    return None
Esempio n. 15
0
def evaluate1():
    """Project VGG16N features of one batch onto 2-D with PCA and plot them.

    Restores the latest checkpoint from ``./logs/train_st2/``, pulls a single
    batch from the input queue, subtracts the per-batch mean image, extracts
    the feature tensor returned by ``VGG.VGG16N``, reduces it to two PCA
    components, saves the projection to ``save_pca.npy`` and scatter-plots
    the points of class ids 0-8, one colour per class.

    Returns early (after printing a message) when no checkpoint is found.
    """
    with tf.Graph().as_default():
        # Training log directory, i.e. where the trained parameters live.
        log_dir = './logs/train_st2/'
        test_dir = './data'

        input_images, input_labels, _ = input_data.read_cifar10(
            data_dir=test_dir,
            is_train=True,
            batch_size=BATCH_SIZE,
            shuffle=True)
        logits, features = VGG.VGG16N(input_images, N_CLASSES, IS_PRETRAIN)
        # This op is added to the graph but is never run by this function.
        tools.num_correct_prediction(logits, input_labels)
        saver = tf.train.Saver(tf.global_variables())

        with tf.Session() as sess:
            print("Reading checkpoints...")
            ckpt = tf.train.get_checkpoint_state(log_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                print("没有找到文件")
                print('No checkpoint file found')
                return

            print("找到文件啦")
            ckpt_name = ckpt.model_checkpoint_path.rsplit('/', 1)[-1]
            global_step = ckpt_name.rsplit('-', 1)[-1]
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Loading success, global_step is %s' % global_step)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                print('\nEvaluating......')
                class_ids = []
                all_feats = np.array([])

                for step in np.arange(1):
                    if coord.should_stop():
                        break
                    images, labels = sess.run([input_images, input_labels])
                    mean_image = np.mean(images, axis=0)
                    # Column index of each non-zero label entry.
                    batch_ids = np.transpose(np.nonzero(labels))[:, 1]
                    class_ids = np.concatenate((class_ids, batch_ids), axis=0)

                    # Feeding the queue tensor replaces the dequeued batch
                    # with the mean-centred images.
                    batch_feats = sess.run(
                        features,
                        feed_dict={input_images: images - mean_image})
                    if step == 0:
                        all_feats = batch_feats
                    else:
                        all_feats = np.concatenate([all_feats, batch_feats])

                plt.figure(figsize=(16, 9))
                pca = PCA(n_components=2)
                projected = pca.fit(all_feats).transform(all_feats)
                np.save('save_pca', projected)

                palette = [
                    '#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff',
                    '#ff00ff', '#990000', '#999900', '#009900'
                ]
                for class_id, colour in enumerate(palette):
                    in_class = class_ids == class_id
                    plt.plot((projected[in_class, 0].flatten()) / 100.0,
                             (projected[in_class, 1].flatten()) / 100.0,
                             '.',
                             markersize=10,
                             c=colour)

                plt.legend([str(number) for number in range(1, 10)])
                plt.grid()
                plt.show()

            except Exception as e:
                # Hand the error to the coordinator; join() below re-raises it.
                coord.request_stop(e)
            finally:
                coord.request_stop()
                coord.join(threads)
def train2(retrain=False):
    """Evaluate the frozen graph ``vgg_6000.pb`` on a single training batch.

    Imports the frozen graph with its ``X`` input remapped to a fresh
    placeholder, builds loss and accuracy on the ``fc8/relu:0`` output, runs
    them once on one batch from the CIFAR-10 training queue and prints the
    two values.

    Args:
        retrain: Unused; kept so existing callers continue to work.
    """
    data_dir = '/home/rong/something_for_deep/cifar-10-batches-bin'

    train_image_batch, train_label_batch = input_data.read_cifar10(
        data_dir=data_dir, is_train=True, batch_size=BATCH_SIZE, shuffle=True)

    # Placeholders for the image batch and the label batch.
    x = tf.placeholder(tf.float32,
                       shape=[BATCH_SIZE, IMG_W, IMG_H, IMG_CHANNELS],
                       name='X')
    y_ = tf.placeholder(tf.int16, shape=[BATCH_SIZE, NUM_CLASSES])

    with open('vgg_6000.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # import_graph_def returns a list with one entry per requested element,
    # so unpack the single tensor instead of passing the list downstream.
    logits, = tf.import_graph_def(graph_def,
                                  input_map={'X': x},
                                  return_elements=['fc8/relu:0'])

    # Key nodes: loss and accuracy.
    loss = tools.loss(logits, y_)
    accuracy = tools.accuracy(logits, y_)

    # The context manager closes the session even if evaluation fails.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # The coordinator manages the queue-runner threads that feed the
        # input pipeline.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            train_images, train_labels = sess.run(
                [train_image_batch, train_label_batch])
            batch_loss, batch_accuracy = sess.run([loss, accuracy],
                                                  feed_dict={
                                                      x: train_images,
                                                      y_: train_labels
                                                  })
            print(batch_loss, batch_accuracy)
        finally:
            coord.request_stop()
            coord.join(threads)
def train(retrain=False):
    """Freeze the VGG16 checkpoint at step 6000 into ``vgg_6000.pb``.

    Despite its name, this function runs no training step. It builds the
    VGG16 graph (including the loss, accuracy and optimizer ops), restores
    ``./logs/train/model.ckpt-6000``, converts the variables reachable from
    ``fc8/relu`` into constants and writes the resulting GraphDef to
    ``vgg_6000.pb`` in the current directory.

    Args:
        retrain: Unused; kept so existing callers continue to work.
    """
    data_dir = '/home/rong/something_for_deep/cifar-10-batches-bin'
    npy_dir = '/home/rong/something_for_deep/vgg16.npy'

    # NOTE(review): the input queues and the optimizer ops below are not run
    # here. They are left in place so the graph (op names, variable set) is
    # built exactly as before -- presumably matching the graph that produced
    # the checkpoint; confirm before pruning.
    train_image_batch, train_label_batch = input_data.read_cifar10(
        data_dir=data_dir, is_train=True, batch_size=BATCH_SIZE, shuffle=True)
    val_image_batch, val_label_batch = input_data.read_cifar10(
        data_dir=data_dir,
        is_train=False,
        batch_size=BATCH_SIZE,
        shuffle=False)

    # Placeholders for the image batch and the label batch.
    x = tf.placeholder(tf.float32,
                       shape=[BATCH_SIZE, IMG_W, IMG_H, IMG_CHANNELS])
    y_ = tf.placeholder(tf.int16, shape=[BATCH_SIZE, NUM_CLASSES])

    # The VGG16 model object.
    vgg = model.VGG16()

    # Key nodes: logits, loss and accuracy.
    logits = vgg.build(x, NUM_CLASSES, IS_PRETRAIN)
    loss = tools.loss(logits, y_)
    accuracy = tools.accuracy(logits, y_)

    my_global_step = tf.Variable(0, name='global_step', trainable=False)
    train_op = tools.optimize(loss, learning_rate, my_global_step)
    train_op2 = tools.optimize2(loss, learning_rate)

    saver = tf.train.Saver()
    summary_op = tf.summary.merge_all()

    init = tf.global_variables_initializer()

    # The context manager closes the session even if restoring or exporting
    # raises.
    with tf.Session() as sess:
        sess.run(init)

        # Load the weights of every layer except the fully connected ones
        # from the .npy file.
        tools.load_with_skip(npy_dir, sess, ['fc6', 'fc7', 'fc8'])

        saver.restore(sess, './logs/train/model.ckpt-6000')
        output_graph_def = convert_variables_to_constants(
            sess, sess.graph_def, output_node_names=['fc8/relu'])

        with tf.gfile.FastGFile('vgg_6000.pb', mode='wb') as f:
            f.write(output_graph_def.SerializeToString())
Esempio n. 18
0
def train():
    """Fine-tune VGG16 on CIFAR-10 with periodic validation and checkpoints.

    The network is built on the ``x`` / ``y_`` placeholders, so the batch
    that is fed is the batch that is evaluated. This is what makes the
    validation numbers reflect validation data. If a checkpoint exists in
    the training log directory it is restored; otherwise all variables are
    initialised and the pre-trained VGG16 weights are loaded for every layer
    except ``fc6``, ``fc7`` and ``fc8``.

    Every 50 steps the training metrics are printed and a training summary
    is written, every 200 steps one validation batch is evaluated and
    summarised, and every 2000 steps a checkpoint is saved. The final step
    triggers all three.
    """
    pre_trained_weights = './VGG16_pretrain/vgg16.npy'
    data_dir = config.dataPath
    train_log_dir = './logs2/train/'
    val_log_dir = './logs2/val/'

    with tf.name_scope('input'):
        train_image_batch, train_label_batch = input_data.read_cifar10(
            data_dir, is_train=True, batch_size=BATCH_SIZE, shuffle=True)

        val_image_batch, val_label_batch = input_data.read_cifar10(
            data_dir, is_train=False, batch_size=BATCH_SIZE, shuffle=False)

    x = tf.placeholder(dtype=tf.float32, shape=[BATCH_SIZE, IMG_H, IMG_W, 3])
    y_ = tf.placeholder(dtype=tf.int32, shape=[BATCH_SIZE, N_CLASSES])

    # Build the model on the placeholders rather than on the training queue
    # tensors; otherwise the feed_dict below is ignored and every
    # "validation" run silently evaluates a fresh training batch.
    logits = VGG.VGG16(x, N_CLASSES, IS_PRETRAIN)
    loss = tools.loss(logits, y_)
    accuracy = tools.accuracy(logits, y_)
    my_global_step = tf.Variable(0, trainable=False, name='global_step')

    # Run the collected UPDATE_OPS together with every optimizer step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = tools.optimize(loss, learning_rate, my_global_step)

    tf.summary.image('input', x, 10)
    saver = tf.train.Saver(tf.global_variables())

    summary_op = tf.summary.merge_all()

    # Restore from a checkpoint when one exists, otherwise initialise.
    print('Reading checkpoint...')
    ckpt = tf.train.get_checkpoint_state(train_log_dir)
    sess = tf.Session()
    if ckpt and ckpt.model_checkpoint_path:
        global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        saver.restore(sess, ckpt.model_checkpoint_path)
        print('Load success, global step: %s' % global_step)
    else:
        init = tf.global_variables_initializer()
        sess.run(init)
        # Load the pre-trained weights, skipping the fully connected layers.
        tools.load_with_skip(pre_trained_weights, sess, ['fc6', 'fc7', 'fc8'])
        print('Load pre_trained_weights success!!!')

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    train_summary_writer = tf.summary.FileWriter(train_log_dir, sess.graph)
    val_summary_writer = tf.summary.FileWriter(val_log_dir, sess.graph)

    try:
        for step in range(MAX_STEP):
            if coord.should_stop():
                break
            is_last_step = (step + 1) == MAX_STEP

            train_images, train_labels = sess.run(
                [train_image_batch, train_label_batch])
            train_feed = {x: train_images, y_: train_labels}
            _, train_loss, train_accuracy = sess.run(
                [train_op, loss, accuracy], feed_dict=train_feed)

            if step % 50 == 0 or is_last_step:
                print("Step: %d, loss: %.4f, accuracy: %.4f%%" %
                      (step, train_loss, train_accuracy))
                summary_str = sess.run(summary_op, feed_dict=train_feed)
                train_summary_writer.add_summary(summary_str, step)

            if step % 200 == 0 or is_last_step:
                val_images, val_labels = sess.run(
                    [val_image_batch, val_label_batch])
                val_feed = {x: val_images, y_: val_labels}
                val_loss, val_accuracy = sess.run([loss, accuracy],
                                                  feed_dict=val_feed)
                print("** Step: %d, loss: %.4f, accuracy: %.4f%%" %
                      (step, val_loss, val_accuracy))
                # The validation summary is computed from the validation
                # batch, not from the training batch.
                summary_str = sess.run(summary_op, feed_dict=val_feed)
                val_summary_writer.add_summary(summary_str, step)

            if step % 2000 == 0 or is_last_step:
                checkpoint_path = os.path.join(train_log_dir, 'model.ckpt')
                saver.save(sess, save_path=checkpoint_path, global_step=step)

    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limited reached')
    finally:
        # Always release resources, including when an unexpected error is
        # propagating: flush the writers, stop and join the queue threads,
        # then close the session.
        train_summary_writer.close()
        val_summary_writer.close()
        coord.request_stop()
        try:
            coord.join(threads)
        finally:
            sess.close()