コード例 #1
0
		def enqueue_batches():
			"""Repeatedly read a training batch with tu.read_batch and run enqueue_op on it until coord.should_stop() is true."""
			while True:
				if coord.should_stop():
					break
				batch_images, batch_labels = tu.read_batch(batch_size, train_img_path, wnid_labels)
				sess.run(enqueue_op, feed_dict={x: batch_images, y: batch_labels})
コード例 #2
0
ファイル: train.py プロジェクト: liuxiaoan8008/bias_net
 def enqueue_batches():
     """Run enqueue_op on freshly read training batches until the coordinator asks to stop."""
     while True:
         if coord.should_stop():
             return
         images, labels = tu.read_batch(batch_size, train_img_path, wnid_labels)
         feed = {x: images, y: labels}
         sess.run(enqueue_op, feed_dict=feed)
コード例 #3
0
ファイル: disAlexNet.py プロジェクト: warmstar1986/disAlexNet
            uninitalized_variables=sess.run(variables_check_op)
	    if(len(uninitalized_variables.shape) == 1):
		state = True
	
	step = 0
	cost = 0
	final_accuracy = 0
	start_time = time.time()
	batch_time = time.time()
	epoch_time = time.time()
	while (not sv.should_stop()):
	    #Read batch_size data
	    val_x, val_y = tu.read_validation_batch_V2(500, '/root/data/ILSVRC/Data/CLS-LOC/val/', '/root/code/disAlexNet/val_10.txt')
	    for e in range(Epoch):
		for i in range(num_batches):
		    batch_x, batch_y = tu.read_batch(batch_size, train_img_path, wnid_labels)
                    _, cost, step = sess.run([train_op, cross_entropy, global_step], feed_dict={x: batch_x, y_: batch_y, keep_prob: 0.5})
		    final_accuracy = sess.run(accuracy, feed_dict = {x: val_x, y_: val_y, keep_prob: 1.0})
		    print("Step: %d," % (step+1), 
			        " Accuracy: %.4f," % final_accuracy,
			        " Loss: %f" % cost,
			        " Bctch_Time: %fs" % float(time.time()-batch_time))
	    	    batch_time = time.time()
		    re = str(step+1)+","+str(final_accuracy)+","+str(float(time.time()-batch_time))+","+str(cost)
		    save = open("test.csv", "a+")
		    save.write(re+"\r\n")
		    save.close()
	    	print("Epoch: %d," % (e+1), 
			" Accuracy: %.4f," % final_accuracy,
			" Loss: %f" % cost,
			" Epoch_Time: %fs" % float(time.time()-epoch_time),
コード例 #4
0
def train(
        epochs,
        batch_size,
        learning_rate,
        dropout,
        momentum,
        lmbda,
        resume,
        display_step,
        test_step,
        ckpt_path,
        summary_path
):
    """Train the ``alexnet.classifier`` network on the 6-class bias image set.

    Training batches are read from disk by background threads and pushed into a
    CPU-side FIFO queue; the network consumes batches dequeued from that queue.

    Args:
        epochs: total number of epochs; when resuming, training continues from
            the epoch implied by the restored ``global_step``.
        batch_size: images per training step (also the queue dequeue size).
        learning_rate: initial learning rate; divided by 10 when the step
            reaches 170000 and 350000.
        dropout: value fed to ``keep_prob`` during training steps.
        momentum: momentum term of ``tf.train.MomentumOptimizer``.
        lmbda: weight-decay coefficient multiplied into the L2 loss of every
            tensor in the ``'weights'`` collection.
        resume: if true, restore variables from the checkpoint instead of
            running the initializer.
        display_step: print training loss/accuracy every this many steps.
        test_step: evaluate "validation" accuracy and save a checkpoint every
            this many steps.
        ckpt_path: directory containing ``alexnet-cnn.ckpt``.
        summary_path: TensorBoard log directory (``<summary_path>/train``).
    """

    train_img_path = '/var/data/bias_data/image/train'
    # NOTE(review): validation images are read from the *training* directory,
    # so the reported validation accuracy is measured on training data — confirm intent.
    evaluate_path = '/var/data/bias_data/image/train'
    num_whole_images = 60000
    num_batches = int(float(num_whole_images) / batch_size)
    wnid_labels = ['cheer_out', 'fearful_out', 'happy_out', 'joy_out', 'rage_out', 'sorrow_out']
    ckpt_file = os.path.join(ckpt_path, 'alexnet-cnn.ckpt')

    x = tf.placeholder(tf.float32, [None, 150, 150, 3])
    y = tf.placeholder(tf.float32, [None, 6])

    lr = tf.placeholder(tf.float32)
    keep_prob = tf.placeholder(tf.float32)

    # queue of examples being filled on the cpu
    with tf.device('/cpu:0'):
        q = tf.FIFOQueue(batch_size * 3, [tf.float32, tf.float32], shapes=[[150, 150, 3], [6]])
        enqueue_op = q.enqueue_many([x, y])
        x_b, y_b = q.dequeue_many(batch_size)
        # Closing with cancel_pending_enqueues=True wakes producer threads that
        # are blocked on a full queue, so shutdown cannot hang on them.
        close_queue_op = q.close(cancel_pending_enqueues=True)

    pred, prob = alexnet.classifier(x_b, keep_prob)

    # cross-entropy and weight decay
    with tf.name_scope('cross_entropy'):
        cross_entropy = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y_b, name='cross-entropy'))

    with tf.name_scope('l2_loss'):
        l2_loss = tf.reduce_sum(lmbda * tf.stack([tf.nn.l2_loss(v) for v in tf.get_collection('weights')]))
        tf.summary.scalar('l2_loss', l2_loss)

    with tf.name_scope('loss'):
        loss = cross_entropy + l2_loss
        tf.summary.scalar('loss', loss)

    # accuracy
    with tf.name_scope('accuracy'):
        correct = tf.equal(tf.argmax(pred, 1), tf.argmax(y_b, 1))
        accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
        tf.summary.scalar('accuracy', accuracy)

    global_step = tf.Variable(0, trainable=False)
    epoch = tf.div(global_step, num_batches)

    # momentum optimizer
    with tf.name_scope('optimizer'):
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr, momentum=momentum).minimize(
            loss, global_step=global_step)
    # Currently unused: the add_summary call in the training loop is commented out.
    merged = tf.summary.merge_all()
    saver = tf.train.Saver()

    coord = tf.train.Coordinator()
    init = tf.global_variables_initializer()

    with tf.Session(config=tf.ConfigProto()) as sess:
        if resume:
            saver.restore(sess, ckpt_file)
        else:
            sess.run(init)

        # enqueuing batches procedure
        def enqueue_batches():
            """Read training batches and push them into the queue until stopped."""
            while not coord.should_stop():
                im, l = tu.read_batch(batch_size, train_img_path, wnid_labels)
                try:
                    sess.run(enqueue_op, feed_dict={x: im, y: l})
                except tf.errors.CancelledError:
                    # The queue was closed during shutdown; exit the thread quietly.
                    return

        # creating and starting parallel threads to fill the queue
        num_threads = 3
        threads = []
        for _ in range(num_threads):
            t = threading.Thread(target=enqueue_batches)
            t.daemon = True
            t.start()
            threads.append(t)

        # operation to write logs for tensorboard visualization
        train_writer = tf.summary.FileWriter(os.path.join(summary_path, 'train'), sess.graph)

        try:
            # NOTE(review): dequeue_many(batch_size) presumably gives x_b/y_b a static
            # leading dimension of batch_size, so feeding 126 rows into them below
            # likely fails unless batch_size == 126 — confirm.
            valid_batch_size = 126
            val_im, val_cls = tu.read_batch(valid_batch_size, evaluate_path, wnid_labels)

            start_time = time.time()
            for e in range(sess.run(epoch), epochs):
                for _ in range(num_batches):

                    _, step = sess.run([optimizer, global_step], feed_dict={lr: learning_rate, keep_prob: dropout})
                    # train_writer.add_summary(summary, step)

                    # decaying learning rate
                    if step == 170000 or step == 350000:
                        learning_rate /= 10

                    # display current training informations
                    if step % display_step == 0:
                        c, a = sess.run([loss, accuracy], feed_dict={lr: learning_rate, keep_prob: 1.0})
                        print('Epoch: {:03d} Step/Batch: {:09d} --- Loss: {:.7f} Training accuracy: {:.4f}'.format(
                            e, step, c, a))

                    # make test and evaluate validation accuracy
                    if step % test_step == 0:
                        v_a = sess.run(accuracy, feed_dict={x_b: val_im, y_b: val_cls, lr: learning_rate, keep_prob: 1.0})
                        # intermediate time
                        int_time = time.time()
                        print('Elapsed time: {}'.format(tu.format_time(int_time - start_time)))
                        print('Validation accuracy: {:.04f}'.format(v_a))
                        # save weights to file
                        save_path = saver.save(sess, ckpt_file)
                        print('Variables saved in file: %s' % save_path)

            end_time = time.time()
            print('Elapsed time: {}'.format(tu.format_time(end_time - start_time)))
            save_path = saver.save(sess, ckpt_file)
            print('Variables saved in file: %s' % save_path)
        finally:
            # Always stop the producers, even if training raised. Closing the queue
            # unblocks threads waiting on a full queue; otherwise coord.join() would
            # wait out its grace period and raise because they are still alive.
            coord.request_stop()
            sess.run(close_queue_op)
            coord.join(threads)
            train_writer.close()