Ejemplo n.º 1
0
def run_training():
    """Build the training graph and run the training loop.

    The graph is built from ``read_data.inputs`` (training split) and the
    ``model_cnn`` inference / evaluation / loss / training ops.  The loop then
    runs ``train_op`` once per iteration and:

    * every 10 steps prints loss/accuracy and writes the merged summaries;
    * every ``num_iter_per_epoch`` steps (except step 0) saves a checkpoint
      and calls ``eval_cnn.evaluate`` on the validation split.

    The loop ends when the coordinator asks to stop or when the input
    pipeline raises ``tf.errors.OutOfRangeError``.

    Raises:
        AssertionError: If the loss value becomes NaN.
    """
    with tf.Graph().as_default():
        train_images, train_labels = read_data.inputs(
            data_set='train', batch_size=BATCH_SIZE, num_epochs=NUM_EPOCHS)
        train_logits = model_cnn.inference(train_images)
        train_accuracy = model_cnn.evaluation(train_logits, train_labels)
        tf.scalar_summary('train_accuracy', train_accuracy)

        loss = model_cnn.loss(train_logits, train_labels)
        train_op = model_cnn.training(loss)

        saver = tf.train.Saver(tf.all_variables(), max_to_keep=1)
        summary_op = tf.merge_all_summaries()
        init_op = tf.initialize_all_variables()

        sess = tf.Session()
        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        summary_writer = tf.train.SummaryWriter(TRAIN_DATA_DIR, sess.graph)

        step = 0
        try:
            # float() forces true division so that ceil() is effective even
            # where `/` on two ints truncates (Python 2).
            num_iter_per_epoch = int(math.ceil(float(NUM_TRAIN_EXAMPLES) / BATCH_SIZE))

            while not coord.should_stop():
                log_this_step = step % 10 == 0

                # Fetch the summary in the SAME run as the training step.
                # A separate sess.run(summary_op) re-evaluates the input ops,
                # which (with queue-based inputs) pulls an extra batch and
                # reports a summary unrelated to the values printed below.
                fetches = [train_op, loss, train_accuracy]
                if log_this_step:
                    fetches.append(summary_op)

                start_time = time.time()
                results = sess.run(fetches)
                duration = time.time() - start_time

                loss_value, train_acc_val = results[1], results[2]
                assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

                if log_this_step:
                    print('Step %d : loss = %.5f , training accuracy = %.1f (%.3f sec)'
                          % (step, loss_value, train_acc_val, duration))
                    summary_writer.add_summary(results[3], step)

                if step % num_iter_per_epoch == 0 and step > 0:  # Do not save for step 0
                    num_epochs = int(step / num_iter_per_epoch)
                    saver.save(sess, CHECKPOINT_FILE_PATH, global_step=step)
                    print('epochs done on training dataset = %d' % num_epochs)
                    eval_cnn.evaluate('validation', checkpoint_dir=TRAIN_DATA_DIR)

                step += 1

        except tf.errors.OutOfRangeError:
            print('Done training for %d epochs, %d steps' % (NUM_EPOCHS, step))
        finally:
            # Always release the threads, the event file and the session,
            # including when an unexpected exception is propagating.
            coord.request_stop()
            try:
                coord.join(threads)
            finally:
                summary_writer.close()
                sess.close()
Ejemplo n.º 2
0
def evaluate(data_set, checkpoint_dir='tmp/train_data'):
    """Evaluate the most recent checkpoint on ``data_set`` and log the accuracy.

    Builds a fresh graph, restores variables from the latest checkpoint found
    in ``checkpoint_dir`` (using the name map returned by
    ``ExponentialMovingAverage.variables_to_restore``), averages the per-batch
    accuracy over ``ceil(num_examples / BATCH_SIZE)`` batches, writes it as a
    ``validation_accuracy`` summary to ``EVAL_DATA_DIR`` and prints it.

    Args:
        data_set: ``'validation'`` or ``'train'``.
        checkpoint_dir: Directory searched for the checkpoint state.

    Returns:
        None.  If no checkpoint is found a message is printed and the
        function returns early.

    Raises:
        ValueError: If ``data_set`` is not one of the supported names.
    """
    with tf.Graph().as_default():
        if data_set == 'validation':
            num_examples = NUM_VALIDATION_EXAMPLES
        elif data_set == 'train':
            num_examples = NUM_TRAIN_EXAMPLES
        else:
            raise ValueError('data_set should be one of \'train\', \'validation\'')

        images, labels = read_data.inputs(data_set=data_set, batch_size=BATCH_SIZE, num_epochs=None)
        logits = model.inference(images)
        accuracy_curr_batch = model.evaluation(logits, labels)

        mov_avg_obj = tf.train.ExponentialMovingAverage(model.MOVING_AVERAGE_DECAY)
        variables_to_restore = mov_avg_obj.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found at %s' % checkpoint_dir)
                return
            saver.restore(sess, ckpt.model_checkpoint_path)

            coord = tf.train.Coordinator()
            threads = []

            try:
                for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                    threads.extend(qr.create_threads(sess, coord, daemon=True, start=True))

                # float() forces true division so that ceil() is effective
                # even where `/` on two ints truncates (Python 2).
                num_iter = int(math.ceil(float(num_examples) / BATCH_SIZE))
                step = 0
                acc_full_epoch = 0
                while step < num_iter and not coord.should_stop():
                    acc_batch_val = sess.run(accuracy_curr_batch)
                    acc_full_epoch += acc_batch_val
                    step += 1

                acc_full_epoch /= num_iter
                tf.summary.scalar('validation_accuracy', acc_full_epoch)
                summary_op = tf.summary.merge_all()
                summary_writer = tf.summary.FileWriter(EVAL_DATA_DIR)
                try:
                    summary_str = sess.run(summary_op)
                    # NOTE(review): `step` here is the number of batches just
                    # evaluated, so every call logs at the same step value --
                    # confirm whether the training step was intended.
                    summary_writer.add_summary(summary_str, step)
                finally:
                    # Close (and thereby flush) the writer; otherwise every
                    # call leaves an open event file behind.
                    summary_writer.close()

                print('Accuracy on full %s dataset = %.1f' % (data_set, acc_full_epoch))

            except Exception as e:
                # Hand the error to the coordinator; coord.join() below
                # re-raises it after the threads have been stopped.
                coord.request_stop(e)

            coord.request_stop()
            coord.join(threads)
Ejemplo n.º 3
0
def inputs(eval_data):
    """Build the input tensors by delegating to ``read_data.inputs``.

    Args:
        eval_data: Forwarded unchanged to ``read_data.inputs``.

    Returns:
        An ``(images, labels)`` pair; both are cast to ``tf.float16`` when
        ``FLAGS.use_fp162`` is set.

    Raises:
        ValueError: If ``FLAGS.data_dir2`` is empty or unset.
    """
    base_dir = FLAGS.data_dir2
    if not base_dir:
        raise ValueError('Data could not been found!')

    batch = read_data.inputs(
        eval_data=eval_data,
        data_dir=os.path.join(base_dir, 'cifar-10-batches-bin'),
        batch_size=FLAGS.batch_size2,
    )
    images, labels = batch

    if not FLAGS.use_fp162:
        return images, labels
    return tf.cast(images, tf.float16), tf.cast(labels, tf.float16)
Ejemplo n.º 4
0
def inputs(eval_data):
    """Build the input tensors by delegating to ``cifar10_input.inputs``.

    Args:
        eval_data: Forwarded unchanged to ``cifar10_input.inputs``.

    Returns:
        An ``(images, labels)`` pair; both are cast to ``tf.float16`` when
        ``FLAGS.use_fp16`` is set.

    Raises:
        ValueError: If ``FLAGS.data_dir`` is empty or unset.
    """
    root = FLAGS.data_dir
    if not root:
        raise ValueError('Please supply a data_dir')

    binaries_dir = os.path.join(root, 'cifar-10-batches-bin')
    images, labels = cifar10_input.inputs(
        eval_data=eval_data,
        data_dir=binaries_dir,
        batch_size=FLAGS.batch_size,
    )

    half_precision = FLAGS.use_fp16
    if half_precision:
        images, labels = tf.cast(images, tf.float16), tf.cast(labels, tf.float16)
    return images, labels
Ejemplo n.º 5
0
def inputs(eval_data):
    """Construct input for CIFAR evaluation using the Reader ops.

    Args:
        eval_data: bool, indicating if one should use the train or eval
            data set.  Forwarded to ``cifar10_input.inputs``.

    Returns:
        images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3]
            size.
        labels: Labels. 1D tensor of [batch_size] size.
        Both are cast to ``tf.float16`` when ``FLAGS.use_fp162`` is set.

    Raises:
        ValueError: If ``FLAGS.data_dir2`` is empty or unset.
    """
    configured_dir = FLAGS.data_dir2
    if not configured_dir:
        raise ValueError('Please supply a data_dir')

    reader_kwargs = {
        'eval_data': eval_data,
        'data_dir': os.path.join(configured_dir, 'cifar-10-batches-bin'),
        'batch_size': FLAGS.batch_size2,
    }
    images, labels = cifar10_input.inputs(**reader_kwargs)

    if FLAGS.use_fp162:
        return tf.cast(images, tf.float16), tf.cast(labels, tf.float16)
    return images, labels
Ejemplo n.º 6
0
def evaluate(test_set, path):
    """Restore a checkpoint and return the accuracy, in percent, on ``test_set``.

    Args:
        test_set: Record file name passed to ``inputs``.  The value
            ``"valid.tfrecords"`` selects ``valid_records`` as the record
            count; any other value selects ``test_records``.
        path: Checkpoint path passed to ``saver.restore``.

    Returns:
        ``100 * correct / evaluated`` as a float, where ``evaluated`` is the
        number of results actually fetched; ``0.0`` if nothing was evaluated.
    """
    with tf.Graph().as_default():

        images, labels = inputs(test_set)

        logits = inference(train=False, images=images)
        test_acc = accuracy(logits, labels)

        saver = tf.train.Saver(tf.all_variables())

        if test_set == "valid.tfrecords":
            num_records = valid_records
        else:
            num_records = test_records
        num_steps = int(num_records / batch_size)

        true_count = 0
        num_evaluated = 0
        threads = []

        sess = tf.Session()
        coord = tf.train.Coordinator()

        try:
            # Restoring inside the try block guarantees the session is closed
            # even when the checkpoint cannot be loaded.
            saver.restore(sess=sess, save_path=path)
            threads = tf.train.start_queue_runners(coord=coord, sess=sess)

            for _ in range(num_steps):
                acc = sess.run(test_acc)
                true_count += np.sum(acc)
                # Count what was really evaluated: only whole batches are
                # run, and the loop may end early on OutOfRangeError, so
                # dividing by num_records would understate the accuracy.
                # Assumes `acc` holds one entry per example -- TODO confirm.
                num_evaluated += np.size(acc)

        except tf.errors.OutOfRangeError as e:
            # Function-call form works on Python 2 and 3; the double space
            # reproduces the output of the former `print "Issues: ", e`.
            print("Issues:  %s" % (e,))
        finally:
            coord.request_stop()
            try:
                coord.join(threads, stop_grace_period_secs=10)
            finally:
                sess.close()

        if num_evaluated == 0:
            return 0.0
        return 100 * (float(true_count) / num_evaluated)
Ejemplo n.º 7
0
def evaluate(test_set, path):
    """Restore a checkpoint and return the accuracy, in percent, on ``test_set``.

    Args:
        test_set: Record file name passed to ``inputs``.  The value
            ``'valid.tfrecords'`` selects ``valid_records`` as the record
            count; any other value selects ``test_records``.
        path: Checkpoint path passed to ``saver.restore``.

    Returns:
        ``100 * correct / evaluated`` as a float, where ``evaluated`` is the
        number of results actually fetched; ``0.0`` if nothing was evaluated.
    """
    with tf.Graph().as_default():

        images, labels = inputs(test_set)

        logits = inference(train=False, images=images)
        test_acc = accuracy(logits, labels)

        saver = tf.train.Saver(tf.all_variables())

        if test_set == 'valid.tfrecords':
            num_records = valid_records
        else:
            num_records = test_records
        num_steps = int(num_records / batch_size)

        true_count = 0
        num_evaluated = 0
        threads = []

        sess = tf.Session()
        coord = tf.train.Coordinator()

        try:
            # Restoring inside the try block guarantees the session is closed
            # even when the checkpoint cannot be loaded.
            saver.restore(sess=sess, save_path=path)
            threads = tf.train.start_queue_runners(coord=coord, sess=sess)

            for _ in range(num_steps):
                acc = sess.run(test_acc)
                true_count += np.sum(acc)
                # Count what was really evaluated: only whole batches are
                # run, and the loop may end early on OutOfRangeError, so
                # dividing by num_records would understate the accuracy.
                # Assumes `acc` holds one entry per example -- TODO confirm.
                num_evaluated += np.size(acc)

        except tf.errors.OutOfRangeError as e:
            # Function-call form works on Python 2 and 3; the double space
            # reproduces the output of the former `print 'Issues: ', e`.
            print('Issues:  %s' % (e,))
        finally:
            coord.request_stop()
            try:
                coord.join(threads, stop_grace_period_secs=10)
            finally:
                sess.close()

        if num_evaluated == 0:
            return 0.0
        return 100 * (float(true_count) / num_evaluated)
Ejemplo n.º 8
0
def train_model(train_set_path, validation_set_path, save_model_path):
    """Train the network on queue-fed batches and save its parameters.

    Builds the network with ``build_network``, minimises the mean softmax
    cross-entropy with Adam, and feeds batches read through ``rd.inputs``
    into the ``x`` / ``y`` placeholders.  Training and validation accuracy
    and cost are printed at step 1 and every 10th step.  When the loop ends
    (coordinator stop or ``tf.errors.OutOfRangeError``) the network
    parameters are written with ``tl.files.save_npz``.

    Args:
        train_set_path: Path passed to ``rd.inputs`` for the training data.
        validation_set_path: Path passed to ``rd.inputs`` for the
            validation data.
        save_model_path: Destination passed to ``tl.files.save_npz``.
    """
    x = tf.placeholder(tf.float32, shape=IMAGE_SET_SHAPE)
    y = tf.placeholder(tf.float32, shape=LABEL_SET_SHAPE)

    net = build_network(x)

    y_out = net.outputs
    y_out = tf.reshape(y_out, shape=LABEL_SET_SHAPE)

    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_out))

    train_step = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)

    y_arg = tf.reshape(tf.argmax(y_out, 1), shape=[BATCH_SIZE])
    correct_prediction = tf.equal(y_arg, tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    tri_img, tri_lbl = rd.inputs(path=train_set_path,
                                 batch_size=BATCH_SIZE,
                                 num_epochs=NUM_EPOCHS,
                                 patch_size=PATCH_SIZE,
                                 channel_num=CHANNEL_NUM)

    val_img, val_lbl = rd.inputs(path=validation_set_path,
                                 batch_size=BATCH_SIZE,
                                 num_epochs=NUM_EPOCHS,
                                 patch_size=PATCH_SIZE,
                                 channel_num=CHANNEL_NUM)

    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())

    sess = tf.InteractiveSession()
    try:
        sess.run(init)

        coord = tf.train.Coordinator()
        thread = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            step = 1
            while not coord.should_stop():
                [tris, tril] = sess.run([tri_img, tri_lbl])
                fd_train = {x: tris, y: reshape_labels(tril)}

                if step % 10 == 0 or step == 1:
                    [vals, vall] = sess.run([val_img, val_lbl])
                    fd_val = {x: vals, y: reshape_labels(vall)}

                    print("----------\nStep {}:\n----------".format(step))

                    # One run per feed yields both metrics from a single
                    # forward pass instead of two separate evaluations.
                    tri_accuracy, tri_cost = sess.run([accuracy, loss],
                                                      feed_dict=fd_train)
                    print("Training accuracy {0:.6f}".format(tri_accuracy))
                    print("Training cost is {0:.6f}".format(tri_cost))

                    val_accuracy, val_cost = sess.run([accuracy, loss],
                                                      feed_dict=fd_val)
                    print("Validation accuracy {0:.6f}".format(val_accuracy))
                    print("Validation cost is {0:.6f}".format(val_cost))

                sess.run(train_step, feed_dict=fd_train)
                step += 1
                # NOTE(review): pauses one second per step; the reason is not
                # evident from this function -- confirm it is intentional.
                time.sleep(1)

        except tf.errors.OutOfRangeError:
            print('---------\nTraining has stopped.')
        finally:
            coord.request_stop()

        # Parameters are saved only when the loop ended normally; the save
        # happens before the session is closed.
        tl.files.save_npz(net.all_params, save_model_path)
        coord.join(thread)
    finally:
        # Close the session on every exit path, including unexpected errors.
        sess.close()
Ejemplo n.º 9
0
def evaluate(data_set, checkpoint_dir='tmp/train_data'):
    """Evaluate the most recent checkpoint on ``data_set`` and log the accuracy.

    Builds a fresh graph, restores variables from the latest checkpoint found
    in ``checkpoint_dir`` (using the name map returned by
    ``ExponentialMovingAverage.variables_to_restore``), averages the per-batch
    accuracy over ``ceil(num_examples / BATCH_SIZE)`` batches, writes it as a
    ``validation_accuracy`` summary to ``EVAL_DATA_DIR`` and prints it.

    Args:
        data_set: ``'validation'``, ``'test'`` or ``'train'``.
        checkpoint_dir: Directory searched for the checkpoint state.

    Returns:
        None.  If no checkpoint is found a message is printed and the
        function returns early.

    Raises:
        ValueError: If ``data_set`` is not one of the supported names.
    """
    with tf.Graph().as_default():

        if data_set == 'validation':
            num_examples = NUM_VALIDATION_EXAMPLES
        elif data_set == 'test':
            num_examples = NUM_TEST_EXAMPLES
        elif data_set == 'train':
            num_examples = NUM_TRAIN_EXAMPLES
        else:
            raise ValueError('data_set should be one of \'train\', \'validation\' or \'test\'')

        # The epoch count is left unspecified (None) for evaluation inputs;
        # per the original note, limiting it would cap the training duration
        # because the validation set is 10 times smaller than the training set.
        images, labels = read_data.inputs(data_set=data_set, batch_size=BATCH_SIZE, num_epochs=None)
        logits = model_cnn.inference(images)
        accuracy_curr_batch = model_cnn.evaluation(logits, labels)

        # Restore moving averages of the trained variables
        mov_avg_obj = tf.train.ExponentialMovingAverage(model_cnn.MOVING_AVERAGE_DECAY)
        variables_to_restore = mov_avg_obj.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found at %s' % checkpoint_dir)
                return
            saver.restore(sess, ckpt.model_checkpoint_path)

            coord = tf.train.Coordinator()
            threads = []

            try:
                for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                    threads.extend(qr.create_threads(sess, coord, daemon=True, start=True))

                # float() forces true division so that ceil() is effective
                # even where `/` on two ints truncates (Python 2).
                num_iter = int(math.ceil(float(num_examples) / BATCH_SIZE))
                step = 0
                acc_full_epoch = 0
                while step < num_iter and not coord.should_stop():
                    acc_batch_val = sess.run(accuracy_curr_batch)
                    acc_full_epoch += acc_batch_val
                    step += 1

                acc_full_epoch /= num_iter
                # NOTE(review): the tag says "validation" for every data_set
                # value -- confirm whether that is intended.
                tf.scalar_summary('validation_accuracy', acc_full_epoch)
                summary_op = tf.merge_all_summaries()
                summary_writer = tf.train.SummaryWriter(EVAL_DATA_DIR)
                try:
                    summary_str = sess.run(summary_op)
                    # NOTE(review): `step` here is the number of batches just
                    # evaluated, so every call logs at the same step value --
                    # confirm whether the training step was intended.
                    summary_writer.add_summary(summary_str, step)
                finally:
                    # Close (and thereby flush) the writer; otherwise every
                    # call leaves an open event file behind.
                    summary_writer.close()

                print('Accuracy on full %s dataset = %.1f' % (data_set, acc_full_epoch))

            except Exception as e:
                # Hand the error to the coordinator; coord.join() below
                # re-raises it after the threads have been stopped.
                coord.request_stop(e)

            coord.request_stop()
            coord.join(threads)