Exemplo n.º 1
0
 def load_data_util(self):
     """Create the training and validation batch readers.

     The record lists come from ``RDL.read_dataset(data_dir)``. Each list is
     wrapped in a ``BDR.BatchDatset``, and both readers share one options dict
     that asks for resizing to ``IMG_SIZE``.

     Returns:
         tuple: ``(train_dataset_reader, validation_dataset_reader)``.
     """
     # NOTE(review): ``data_dir`` is resolved as a free (module-level) name,
     # not as ``self.data_dir``; confirm that is intended.
     resize_opts = dict(resize=True, resize_size=IMG_SIZE)
     train_recs, valid_recs = RDL.read_dataset(data_dir)
     # Build the train reader first, then the validation reader, so the
     # construction order matches the returned tuple.
     readers = tuple(
         BDR.BatchDatset(recs, resize_opts) for recs in (train_recs, valid_recs))
     return readers
Exemplo n.º 2
0
def _scalar_summary(tag, value):
    """Wrap a Python scalar in a ``tf.Summary`` protobuf for a FileWriter."""
    return tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])


def _froc_score(predictions, annotations):
    """Return the average FROC score of ``predictions`` vs ``annotations``.

    Delegates to ``computeFROC`` and ``get_FROC_avg_score`` using the
    ``allowedDistance``, ``range_threshold`` and ``nbr_of_thresholds``
    settings defined elsewhere in this module.
    """
    sensitivity_list, _fp_avg_list = computeFROC(predictions, annotations,
                                                 allowedDistance,
                                                 range_threshold)
    return get_FROC_avg_score(sensitivity_list, nbr_of_thresholds)


def _save_prediction_images(images, annotations, predictions, out_dir):
    """Save an input / ground-truth / prediction image triple per sample.

    Args:
        images: batch of input images, indexable per sample.
        annotations: ground-truth masks with the channel axis already squeezed.
        predictions: predicted masks with the channel axis already squeezed.
        out_dir: directory to write into; it is created if missing.

    Files are named ``inp_<k>``, ``gt_<k>`` and ``pred_<k>`` with ``k``
    starting at 1. One triple is written per prediction, so the loop never
    indexes past the data that was actually predicted.
    """
    import os  # local import: only needed to ensure the output dir exists

    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    for idx, pred in enumerate(predictions):
        utils.save_image(images[idx].astype(np.uint8),
                         out_dir,
                         name="inp_" + str(1 + idx))
        utils.save_image(annotations[idx].astype(np.uint8),
                         out_dir,
                         name="gt_" + str(1 + idx))
        utils.save_image(pred.astype(np.uint8),
                         out_dir,
                         name="pred_" + str(1 + idx))
        print("Saved image: %d" % idx)
        sys.stdout.flush()


def main(argv=None):
    """Build the segmentation graph, then train, visualize or test it.

    The behaviour depends on ``FLAGS.mode``:

    * ``"train"``: run ``MAX_ITERATION`` optimisation steps. Every 15 steps,
      log the loss, the FROC scores and TensorBoard summaries. Checkpoint
      every 5000 steps and after the final step.
    * ``"visualize"``: predict a single test instance and save the images.
    * anything else: predict the whole test set and save the images.

    Any checkpoint found in ``FLAGS.logs_dir`` is restored first.

    Args:
        argv: Ignored.
    """
    import os  # local import: used for portable path construction

    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    image = tf.placeholder(tf.float32,
                           shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3],
                           name="input_image")
    annotation = tf.placeholder(tf.int32,
                                shape=[None, IMAGE_SIZE, IMAGE_SIZE, 1],
                                name="annotation")

    pred_annotation, logits = inference(image, keep_probability)

    tf.summary.image("input_image", image, max_outputs=2)
    tf.summary.image("ground_truth",
                     tf.cast(annotation, tf.uint8),
                     max_outputs=2)
    tf.summary.image("pred_annotation",
                     tf.cast(pred_annotation, tf.uint8),
                     max_outputs=2)
    # `axis` is the non-deprecated name of the `squeeze_dims` keyword.
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=tf.squeeze(annotation, axis=[3]),
        name="entropy"))
    # NOTE(review): the original author warned that this loss "will be wrong"
    # when the number of classes is 2; the reason is not visible here.
    tf.summary.scalar("training_entropy_loss", loss)

    trainable_var = tf.trainable_variables()
    if FLAGS.debug:
        for var in trainable_var:
            utils.add_to_regularization_and_summary(var)
    train_op = train(loss, trainable_var)

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()

    # NOTE: records are read from FLAGS.test_dir. To train on the training
    # split, read from FLAGS.data_dir instead.
    print("Setting up testing image reader...")
    train_records, valid_records = scene_parsing.read_dataset(FLAGS.test_dir)
    print(len(train_records))
    print(len(valid_records))

    print("Setting up dataset reader")
    image_options = {'resize': True, 'resize_size': IMAGE_SIZE}
    # The training reader is only needed (and only built) in train mode.
    if FLAGS.mode == 'train':
        train_dataset_reader = dataset.BatchDatset(train_records,
                                                   image_options)
    validation_dataset_reader = dataset.BatchDatset(valid_records,
                                                    image_options)

    # os.path.join tolerates logs_dir with or without a trailing slash; the
    # trailing '' keeps the separator the original 'images/' suffix had.
    images_dir = os.path.join(FLAGS.logs_dir, 'images', '')
    ckpt_path = os.path.join(FLAGS.logs_dir, "model.ckpt")

    sess = tf.Session()

    print("Setting up Saver...")
    saver = tf.train.Saver(max_to_keep=2)
    summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

    try:
        sess.run(tf.global_variables_initializer())

        # If the model has been trained before, resume from its checkpoint.
        ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("Model restored...")

        if FLAGS.mode == "train":
            for itr in xrange(MAX_ITERATION):
                train_images, train_annotations = \
                    train_dataset_reader.next_batch(FLAGS.batch_size)
                feed_dict = {
                    image: train_images,
                    annotation: train_annotations,
                    keep_probability: 0.85
                }
                sess.run(train_op, feed_dict=feed_dict)

                if itr % 15 == 0:
                    valid_images, valid_annotations = \
                        validation_dataset_reader.next_batch(
                            FLAGS.validation_batch_size)

                    train_loss, summary_str = sess.run([loss, summary_op],
                                                       feed_dict=feed_dict)
                    valid_loss = sess.run(loss,
                                          feed_dict={
                                              image: valid_images,
                                              annotation: valid_annotations,
                                              keep_probability: 1.0
                                          })
                    pre_train_image = sess.run(pred_annotation,
                                               feed_dict={
                                                   image: train_images,
                                                   keep_probability: 1.0
                                               })
                    pre_valid_image = sess.run(pred_annotation,
                                               feed_dict={
                                                   image: valid_images,
                                                   keep_probability: 1.0
                                               })

                    froc_score_t = _froc_score(pre_train_image,
                                               train_annotations)
                    froc_score = _froc_score(pre_valid_image,
                                             valid_annotations)
                    print('froc_score_traing:', froc_score_t)
                    print('froc_score:', froc_score)

                    summary_writer.add_summary(summary_str, itr)
                    summary_writer.add_summary(
                        _scalar_summary("froc_score_training", froc_score_t),
                        itr)
                    summary_writer.add_summary(
                        _scalar_summary("froc_score_validation", froc_score),
                        itr)
                    summary_writer.add_summary(
                        _scalar_summary("validation_loss", valid_loss), itr)
                    summary_writer.flush()

                    print("Step: %d, learning_rate:%g" %
                          (itr, FLAGS.learning_rate))
                    print("Step: %d, Train_loss:%g" % (itr, train_loss))
                    print("Step: %d, Validation_loss:%g" % (itr, valid_loss))
                    sys.stdout.flush()

                # Checkpoint periodically, and always after the final step so
                # the trailing (up to 4999) iterations are not discarded.
                if itr % 5000 == 0 or itr == MAX_ITERATION - 1:
                    saver.save(sess, ckpt_path, itr)
                    sys.stdout.flush()

        elif FLAGS.mode == "visualize":
            print("ONE TEST!!!")
            # The argument is the index of the test instance to load.
            valid_images, valid_annotations = \
                validation_dataset_reader.get_test(0)
            pred = sess.run(pred_annotation,
                            feed_dict={
                                image: valid_images,
                                annotation: valid_annotations,
                                keep_probability: 1.0
                            })
            valid_annotations = np.squeeze(valid_annotations, axis=3)
            pred = np.squeeze(pred, axis=3)
            # Save one triple per predicted sample. The old loop ran
            # FLAGS.validation_batch_size times and over-indexed this
            # single-instance batch.
            _save_prediction_images(valid_images, valid_annotations, pred,
                                    images_dir)

        else:  # FLAGS.mode == "test"
            print("GETTING ALL TEST!!!")
            test_images, test_annotations = \
                validation_dataset_reader.get_test_avg()
            pred = sess.run(pred_annotation,
                            feed_dict={
                                image: test_images,
                                annotation: test_annotations,
                                keep_probability: 1.0
                            })
            test_annotations = np.squeeze(test_annotations, axis=3)
            pred = np.squeeze(pred, axis=3)
            _save_prediction_images(test_images, test_annotations, pred,
                                    images_dir)
    finally:
        # Flush pending events and release the session's resources.
        summary_writer.close()
        sess.close()