コード例 #1
0
def main(argv=None):
    """Run FCN4 inference over the test input queue and save the results.

    For each of ``test_batch_num`` steps, fetches one batch from the queue
    built on the module-level ``img`` / ``filename_queue`` tensors, runs
    ``FCN4.inference`` on it, and writes the input image and the predicted
    mask under ``test_saved_dir`` via ``utils.save_image``.  When
    ``pre_threshold`` is set, the in-graph prediction comes from
    ``utils.argmax_pre_threshold``; it is dumped to
    ``<pre_threshold_dir>/<name>.h5`` and then passed through
    ``utils.apply_threshold`` before being saved.  Otherwise
    ``utils.argmax_threshold`` is applied directly in the graph.

    The latest checkpoint in ``saved_dir`` is restored if one exists.

    Depends on names defined outside this function (``sess``, ``img``,
    ``filename_queue``, ``FCN4``, ``utils``, ``h5``, ``test_batch_size``,
    ``test_batch_num``, ``IMAGE_SIZE_W``, ``IMAGE_SIZE_L``,
    ``pre_threshold``, ``threshold``, ``saved_dir``, ``test_saved_dir``,
    ``pre_threshold_dir``).

    Args:
        argv: Unused (presumably kept for ``tf.app.run`` compatibility).
    """
    batch_size = test_batch_size
    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    image_batch, name_batch = tf.train.batch([img, filename_queue[0]], batch_size)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        logits = FCN4.inference(image_batch, keep_probability)

        if pre_threshold:
            # The threshold is applied later with utils.apply_threshold,
            # after the raw map has been written to disk.
            pred_annotation = utils.argmax_pre_threshold(logits)
        else:
            pred_annotation = utils.argmax_threshold(logits, threshold)
        pred_annotation = tf.reshape(pred_annotation, [batch_size, IMAGE_SIZE_W, IMAGE_SIZE_L, 1])

        print("Setting up Saver...")
        saver = tf.train.Saver()

        ckpt = tf.train.get_checkpoint_state(saved_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("Model restored...")
        else:
            # Nothing in this function initialises the model variables, so a
            # missing checkpoint should at least be visible.
            print("No checkpoint found in %s, model was not restored." % saved_dir)

        # Create output directories once, before the loop.
        if not os.path.isdir(test_saved_dir):
            os.makedirs(test_saved_dir)
        if pre_threshold and not os.path.isdir(pre_threshold_dir):
            os.makedirs(pre_threshold_dir)

        for counter in range(test_batch_num):
            pred, image_test, names = sess.run([pred_annotation, image_batch, name_batch],
                                               feed_dict={keep_probability: 1.0})

            # Base file name without directory or extension. ``str.strip('.png')``
            # removes any of the characters '.', 'p', 'n', 'g' from both ends
            # (e.g. "img.png" -> "im"), so it must not be used for this.
            image_name = os.path.splitext(os.path.basename(names[0].decode()))[0]

            # Only the first batch element's name is used, and the pred reshape
            # below succeeds only when test_batch_size == 1 (pred_annotation is
            # shaped [batch_size, W, L, 1]).
            image_test = np.reshape(image_test, [IMAGE_SIZE_W, IMAGE_SIZE_L, 3])
            pred = np.squeeze(pred, axis=3)
            pred = np.reshape(pred, [IMAGE_SIZE_W, IMAGE_SIZE_L])

            if pre_threshold:
                # Save the raw prediction, then threshold it for the image output.
                with h5.File(os.path.join(pre_threshold_dir, image_name + '.h5'), 'w') as f:
                    f['data'] = pred
                pred = utils.apply_threshold(pred, threshold)

            utils.save_image(image_test, os.path.join(test_saved_dir, 'image/'),
                             name=image_name + "_image")
            utils.save_image(pred, os.path.join(test_saved_dir, 'pred/'),
                             name=image_name + "_pred", category=True)
            print("Saved no. %s image: %s" % (counter, image_name + '.png'))
    finally:
        # Shut the queue-runner threads down even if anything above raised.
        coord.request_stop()
        coord.join(threads)
コード例 #2
0
ファイル: FCN.py プロジェクト: FightingZhen/FCN
def main(argv=None):
    """Train or evaluate the FCN, depending on the module-level ``mode``.

    ``mode == "Train"``: minimises the mean sparse softmax cross-entropy
    between the logits and ``mask_batch`` with Adam for ``MAX_ITERATION``
    steps.  Every 10 steps it prints the loss, writes summaries to
    ``logs_dir`` and checkpoints to ``saved_dir``.  A final checkpoint is
    written after the last step if it was not already saved.

    ``mode == "Test"``: restores the latest checkpoint from ``saved_dir``.
    Then, for ``test_batch_num`` steps, it saves the input image,
    ground-truth mask and prediction under ``test_saved_dir``.

    Depends on names defined outside this function (``sess``, ``img``,
    ``mask``, ``filename_queue``, ``mode``, ``training_batch_size``,
    ``test_batch_size``, ``inference``, ``utils``, ``threshold``,
    ``IMAGE_SIZE_W``, ``IMAGE_SIZE_L``, ``learning_rate``, ``global_step``,
    ``logs_dir``, ``saved_dir``, ``MAX_ITERATION``, ``test_saved_dir``,
    ``test_batch_num``).

    Args:
        argv: Unused (presumably kept for ``tf.app.run`` compatibility).

    Raises:
        ValueError: If ``mode`` is neither "Train" nor "Test".
    """
    if mode == "Train":
        batch_size = training_batch_size
    elif mode == "Test":
        batch_size = test_batch_size
    else:
        # Without this, ``batch_size`` would be unbound and the function would
        # fail later with an UnboundLocalError.
        raise ValueError("mode must be 'Train' or 'Test', got %r" % (mode,))
    is_train = mode == "Train"
    log_every = 10

    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    image_batch, mask_batch, name_batch = tf.train.batch([img, mask, filename_queue[0]], batch_size)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    summary_writer = None

    try:
        logits = inference(image_batch, keep_probability)

        pred_annotation = utils.argmax_threshold(logits, threshold)
        pred_annotation = tf.reshape(pred_annotation, [batch_size, IMAGE_SIZE_W, IMAGE_SIZE_L, 1])
        logits = tf.reshape(logits, [batch_size, IMAGE_SIZE_W, IMAGE_SIZE_L, 2])

        summary_op = None
        if is_train:
            tf.summary.image("input_image", image_batch, max_outputs=10)
            tf.summary.image("ground_truth", tf.cast(mask_batch, tf.uint8) * 255, max_outputs=10)
            tf.summary.image("pred_annotation", tf.cast(pred_annotation, tf.uint8) * 255, max_outputs=10)

            # Labels drop the trailing channel axis: [B, W, L, 1] -> [B, W, L].
            loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits,
                labels=tf.squeeze(mask_batch, axis=[3]),
                name="entropy"))
            tf.summary.scalar("entropy", loss)

            train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)

            print("Setting up summary op...")
            summary_op = tf.summary.merge_all()

        # Created after the optimizer so that its slot variables are saved too.
        saver = tf.train.Saver()
        if is_train:
            print("Setting up Saver...")
            summary_writer = tf.summary.FileWriter(logs_dir, sess.graph)

        sess.run(tf.global_variables_initializer())

        if is_train:
            checkpoint_path = os.path.join(saved_dir, 'model.ckpt')
            feed_dict = {keep_probability: 0.85}
            for itr in range(MAX_ITERATION):
                sess.run(train_op, feed_dict=feed_dict)

                if itr % log_every == 0:
                    train_loss, summary_str, lr = sess.run([loss, summary_op, learning_rate], feed_dict=feed_dict)
                    print("Step: %d, Train_loss:%g, learning_rate: %g" % (itr, train_loss, lr))
                    summary_writer.add_summary(summary_str, itr)

                    saver.save(sess, checkpoint_path, global_step=global_step)

            # The loop only checkpoints on multiples of ``log_every``; persist
            # the final weights if the last step was not one of them.
            if MAX_ITERATION > 0 and (MAX_ITERATION - 1) % log_every != 0:
                saver.save(sess, checkpoint_path, global_step=global_step)
        else:
            ckpt = tf.train.get_checkpoint_state(saved_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                print("Model restored...")
            else:
                # Otherwise the freshly initialised (untrained) weights are used.
                print("No checkpoint found in %s, using initialised weights." % saved_dir)

            if not os.path.isdir(test_saved_dir):
                os.makedirs(test_saved_dir)

            for _ in range(test_batch_num):
                pred, image_test, mask_test, names = sess.run(
                    [pred_annotation, image_batch, mask_batch, name_batch],
                    feed_dict={keep_probability: 1.0})

                # Base file name without directory or extension. ``str.strip('.png')``
                # removes any of '.', 'p', 'n', 'g' from both ends (e.g.
                # "img.png" -> "im"), so it must not be used for this.
                img_name = os.path.splitext(os.path.basename(names[0].decode()))[0]
                print(img_name)

                # Only the first batch element's name is used, and the reshapes
                # below succeed only when test_batch_size == 1.
                mask_test = np.squeeze(mask_test, axis=3)
                pred = np.squeeze(pred, axis=3)

                image_test = np.reshape(image_test, [IMAGE_SIZE_W, IMAGE_SIZE_L, 3])
                pred = np.reshape(pred, [IMAGE_SIZE_W, IMAGE_SIZE_L])
                mask_test = np.reshape(mask_test, [IMAGE_SIZE_W, IMAGE_SIZE_L])

                utils.save_image(image_test, os.path.join(test_saved_dir, 'image/'),
                                 name=img_name + "_image")
                utils.save_image(mask_test, os.path.join(test_saved_dir, 'mask/'),
                                 name=img_name + "_mask", category=True)
                # Trailing slash added to match 'image/' and 'mask/'.
                utils.save_image(pred, os.path.join(test_saved_dir, 'pred/'),
                                 name=img_name + "_pred", category=True)
                print("Saved image: %s" % (img_name + '.png'))
    finally:
        # Flush pending summaries and shut down the queue-runner threads even
        # if training or evaluation raised.
        if summary_writer is not None:
            summary_writer.close()
        coord.request_stop()
        coord.join(threads)