Example #1
def main(argv=None):
    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    image = tf.placeholder(tf.float32,
                           shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3],
                           name="input_image")
    annotation = tf.placeholder(tf.int32,
                                shape=[None, IMAGE_SIZE, IMAGE_SIZE, 1],
                                name="annotation")
    soft_annotation = tf.placeholder(tf.float32,
                                     shape=[None, IMAGE_SIZE, IMAGE_SIZE, 2],
                                     name="soft_annotation")
    pred_annotation_value, pred_annotation, logits, pred_prob = inference(
        image, keep_probability)
    tf.summary.image("input_image", image, max_outputs=2)
    tf.summary.image("ground_truth",
                     tf.cast(annotation, tf.uint8),
                     max_outputs=2)
    tf.summary.image("pred_annotation",
                     tf.cast(pred_annotation, tf.uint8),
                     max_outputs=2)
    # logits: the last layer of the conv net
    # labels: the ground truth
    loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=tf.squeeze(annotation, axis=[3]),
        name="entropy")))
    tf.summary.scalar("entropy", loss)

    trainable_var = tf.trainable_variables()
    if FLAGS.debug:
        for var in trainable_var:
            utils.add_to_regularization_and_summary(var)
    train_op = train(loss, trainable_var)

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()

    #Create a file to write logs.
    #filename='logs'+ FLAGS.mode + str(datetime.datetime.now()) + '.txt'
    filename = "logs_%s%s.txt" % (FLAGS.mode, datetime.datetime.now())
    path_ = os.path.join(FLAGS.logs_dir, filename)
    logs_file = open(path_, 'w')
    logs_file.write("The logs file is created at %s\n" %
                    datetime.datetime.now())
    logs_file.write("The mode is %s\n" % (FLAGS.mode))
    logs_file.write(
        "The train data batch size is %d and the validation batch size is %d.\n"
        % (FLAGS.batch_size, FLAGS.v_batch_size))
    logs_file.write("The train data is %s.\n" % (FLAGS.data_dir))
    logs_file.write("The data size is %d and the MAX_ITERATION is %d.\n" %
                    (IMAGE_SIZE, MAX_ITERATION))
    logs_file.write("The model is ---%s---.\n" % FLAGS.logs_dir)

    print("Setting up image reader...")
    logs_file.write("Setting up image reader...\n")
    train_records, valid_records = scene_parsing.my_read_dataset(
        FLAGS.data_dir)
    print('number of train_records', len(train_records))
    print('number of valid_records', len(valid_records))
    logs_file.write('number of train_records %d\n' % len(train_records))
    logs_file.write('number of valid_records %d\n' % len(valid_records))

    print("Setting up dataset reader")
    image_options = {'resize': True, 'resize_size': IMAGE_SIZE}
    if FLAGS.mode == 'check_training':
        train_dataset_reader = dataset_soft.BatchDatset(
            train_records, image_options)
    validation_dataset_reader = dataset.BatchDatset(valid_records,
                                                    image_options)

    sess = tf.Session()

    print("Setting up Saver...")
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
    # If not training, restore the previously trained model.
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Model restored...")

    if FLAGS.mode == "accurary":
        count = 0
        if_con = True
        accu_iou_t = 0
        accu_pixel_t = 0

        while if_con:
            count = count + 1
            valid_images, valid_annotations, valid_filenames, if_con, start, end = validation_dataset_reader.next_batch_valid(
                FLAGS.v_batch_size)
            valid_loss, pred_anno = sess.run(
                [loss, pred_annotation],
                feed_dict={
                    image: valid_images,
                    annotation: valid_annotations,
                    keep_probability: 1.0
                })
            accu_iou, accu_pixel = accu.caculate_accurary(
                pred_anno, valid_annotations)
            print("Ture %d ---> the data from %d to %d" % (count, start, end))
            print("%s ---> Validation_pixel_accuary: %g" %
                  (datetime.datetime.now(), accu_pixel))
            print("%s ---> Validation_iou_accuary: %g" %
                  (datetime.datetime.now(), accu_iou))
            #Output logs.
            logs_file.write("Ture %d ---> the data from %d to %d\n" %
                            (count, start, end))
            logs_file.write("%s ---> Validation_pixel_accuary: %g\n" %
                            (datetime.datetime.now(), accu_pixel))
            logs_file.write("%s ---> Validation_iou_accuary: %g\n" %
                            (datetime.datetime.now(), accu_iou))

            accu_iou_t = accu_iou_t + accu_iou
            accu_pixel_t = accu_pixel_t + accu_pixel
        print("%s ---> Total validation_pixel_accuary: %g" %
              (datetime.datetime.now(), accu_pixel_t / count))
        print("%s ---> Total validation_iou_accuary: %g" %
              (datetime.datetime.now(), accu_iou_t / count))
        #Output logs
        logs_file.write("%s ---> Total validation_pixel_accurary: %g\n" %
                        (datetime.datetime.now(), accu_pixel_t / count))
        logs_file.write("%s ---> Total validation_iou_accurary: %g\n" %
                        (datetime.datetime.now(), accu_iou_t / count))

    elif FLAGS.mode == "all_visualize":

        re_save_dir = "%s%s" % (FLAGS.result_dir, datetime.datetime.now())
        logs_file.write("The result is save at file'%s'.\n" % re_save_dir)
        logs_file.write("The number of part visualization is %d.\n" %
                        FLAGS.v_batch_size)

        # Create the result directory if it does not exist.
        if not os.path.exists(re_save_dir):
            print("The path '%s' is not found." % re_save_dir)
            print("Create now ...")
            os.makedirs(re_save_dir)
            print("Create '%s' successfully." % re_save_dir)
            logs_file.write("Create '%s' successfully.\n" % re_save_dir)
        re_save_dir_anno = os.path.join(re_save_dir, 'anno')
        re_save_dir_pred = os.path.join(re_save_dir, 'pred')
        re_save_dir_train_heat = os.path.join(re_save_dir, 'train_heat')
        re_save_dir_heat = os.path.join(re_save_dir, 'heatmap')
        re_save_dir_ellip = os.path.join(re_save_dir, 'ellip')
        re_save_dir_transheat = os.path.join(re_save_dir, 'transheat')
        if not os.path.exists(re_save_dir_anno):
            os.makedirs(re_save_dir_anno)
        if not os.path.exists(re_save_dir_pred):
            os.makedirs(re_save_dir_pred)
        if not os.path.exists(re_save_dir_train_heat):
            os.makedirs(re_save_dir_train_heat)
        if not os.path.exists(re_save_dir_heat):
            os.makedirs(re_save_dir_heat)
        if not os.path.exists(re_save_dir_ellip):
            os.makedirs(re_save_dir_ellip)
        if not os.path.exists(re_save_dir_transheat):
            os.makedirs(re_save_dir_transheat)
        count = 0
        if_con = True
        accu_iou_t = 0
        accu_pixel_t = 0

        while if_con:
            count = count + 1
            valid_images, valid_annotations, valid_filename, if_con, start, end = validation_dataset_reader.next_batch_valid(
                FLAGS.v_batch_size)
            pred_value, pred, logits_, pred_prob_ = sess.run(
                [pred_annotation_value, pred_annotation, logits, pred_prob],
                feed_dict={
                    image: valid_images,
                    annotation: valid_annotations,
                    keep_probability: 1.0
                })
            #valid_annotations = np.squeeze(valid_annotations, axis=3)
            pred = np.squeeze(pred, axis=3)
            pred_value = np.squeeze(pred_value, axis=3)
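            # Bias class 0 upward by 0.85 before the argmax, so a pixel is
            # labeled class 1 only when its class-1 probability clearly dominates.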
            pred_ellip = np.argmax(pred_prob_ + [0.85, 0], axis=3)
            #label_predict_pixel
            for itr in range(len(pred)):
                filename = valid_filename[itr]['filename']
                if FLAGS.anno == 'T':
                    valid_images_anno = anno_visualize(
                        valid_images[itr].copy(), valid_annotations[itr, :, :,
                                                                    1])
                    utils.save_image(valid_images_anno.astype(np.uint8),
                                     re_save_dir_anno,
                                     name="anno_" + filename)
                if FLAGS.pred == 'T':
                    valid_images_pred = pred_visualize(
                        valid_images[itr].copy(),
                        np.expand_dims(pred_ellip[itr], axis=2))
                    utils.save_image(valid_images_pred.astype(np.uint8),
                                     re_save_dir_pred,
                                     name="pred_" + filename)
                if FLAGS.train_heat == 'T':
                    heat_map = density_heatmap(
                        valid_annotations[itr, :, :, 1] / FLAGS.normal)
                    utils.save_image(heat_map.astype(np.uint8),
                                     re_save_dir_train_heat,
                                     name="trainheat_" + filename)

                if FLAGS.fit_ellip == 'T':
                    valid_images_ellip = fit_ellipse_findContours(
                        valid_images[itr].copy(),
                        np.expand_dims(pred_ellip[itr],
                                       axis=2).astype(np.uint8))
                    utils.save_image(valid_images_ellip.astype(np.uint8),
                                     re_save_dir_ellip,
                                     name="ellip_" + filename)
                if FLAGS.heatmap == 'T':
                    heat_map = density_heatmap(pred_prob_[itr, :, :, 1])
                    utils.save_image(heat_map.astype(np.uint8),
                                     re_save_dir_heat,
                                     name="heat_" + filename)
                if FLAGS.trans_heat == 'T':
                    trans_heat_map = translucent_heatmap(
                        valid_images[itr],
                        heat_map.astype(np.uint8).copy())
                    utils.save_image(trans_heat_map,
                                     re_save_dir_transheat,
                                     name="trans_heat_" + filename)

    elif FLAGS.mode == 'check_training':
        re_save_dir = "%s%s" % (FLAGS.result_dir, datetime.datetime.now())
        logs_file.write("The result is save at file'%s'.\n" % re_save_dir)
        logs_file.write("The number of part visualization is %d.\n" %
                        FLAGS.v_batch_size)

        # Create the result directory if it does not exist.
        if not os.path.exists(re_save_dir):
            print("The path '%s' is not found." % re_save_dir)
            print("Create now ...")
            os.makedirs(re_save_dir)
            print("Create '%s' successfully." % re_save_dir)
            logs_file.write("Create '%s' successfully.\n" % re_save_dir)
        re_save_dir_anno = os.path.join(re_save_dir, 'anno')
        re_save_dir_pred = os.path.join(re_save_dir, 'pred')
        re_save_dir_train_heat = os.path.join(re_save_dir, 'train_heat')
        re_save_dir_heat = os.path.join(re_save_dir, 'heatmap')
        re_save_dir_ellip = os.path.join(re_save_dir, 'ellip')
        re_save_dir_transheat = os.path.join(re_save_dir, 'transheat')
        if not os.path.exists(re_save_dir_anno):
            os.makedirs(re_save_dir_anno)
        if not os.path.exists(re_save_dir_pred):
            os.makedirs(re_save_dir_pred)
        if not os.path.exists(re_save_dir_train_heat):
            os.makedirs(re_save_dir_train_heat)
        if not os.path.exists(re_save_dir_heat):
            os.makedirs(re_save_dir_heat)
        if not os.path.exists(re_save_dir_ellip):
            os.makedirs(re_save_dir_ellip)
        if not os.path.exists(re_save_dir_transheat):
            os.makedirs(re_save_dir_transheat)
        count = 0
        if_con = True
        accu_iou_t = 0
        accu_pixel_t = 0

        while if_con:
            count = count + 1
            valid_images, valid_annotations, valid_filename, if_con, start, end = train_dataset_reader.next_batch_valid(
                FLAGS.v_batch_size)
            pred_value, pred, logits_, pred_prob_ = sess.run(
                [pred_annotation_value, pred_annotation, logits, pred_prob],
                feed_dict={
                    image: valid_images,
                    soft_annotation: valid_annotations,
                    keep_probability: 1.0
                })
            #valid_annotations = np.squeeze(valid_annotations, axis=3)
            pred = np.squeeze(pred, axis=3)
            pred_value = np.squeeze(pred_value, axis=3)

            #label_predict_pixel
            for itr in range(len(pred)):
                filename = valid_filename[itr]['filename']
                if FLAGS.anno == 'T':
                    valid_images_anno = anno_visualize(
                        valid_images[itr].copy(), valid_annotations[itr, :, :,
                                                                    1])
                    utils.save_image(valid_images_anno.astype(np.uint8),
                                     re_save_dir_anno,
                                     name="anno_" + filename)
                if FLAGS.pred == 'T':
                    valid_images_pred = soft_pred_visualize(
                        valid_images[itr].copy(), pred[itr])
                    utils.save_image(valid_images_pred.astype(np.uint8),
                                     re_save_dir_pred,
                                     name="pred_" + filename)
                if FLAGS.train_heat == 'T':
                    heat_map = density_heatmap(
                        valid_annotations[itr, :, :, 1] / FLAGS.normal)
                    utils.save_image(heat_map.astype(np.uint8),
                                     re_save_dir_train_heat,
                                     name="trainheat_" + filename)

                if FLAGS.fit_ellip == 'T':
                    valid_images_ellip = fit_ellipse_findContours(
                        valid_images[itr].copy(),
                        np.expand_dims(pred[itr], axis=2).astype(np.uint8))
                    utils.save_image(valid_images_ellip.astype(np.uint8),
                                     re_save_dir_ellip,
                                     name="ellip_" + filename)
                if FLAGS.heatmap == 'T':
                    heat_map = density_heatmap(pred_prob_[itr, :, :, 1])
                    utils.save_image(heat_map.astype(np.uint8),
                                     re_save_dir_heat,
                                     name="heat_" + filename)
                if FLAGS.trans_heat == 'T':
                    trans_heat_map = translucent_heatmap(
                        valid_images[itr],
                        heat_map.astype(np.uint8).copy())
                    utils.save_image(trans_heat_map,
                                     re_save_dir_transheat,
                                     name="trans_heat_" + filename)

    logs_file.close()
    if FLAGS.mode == "check_training" or FLAGS.mode == "all_visualize":
        result_logs_file = os.path.join(re_save_dir, filename)
        shutil.copyfile(path_, result_logs_file)
Example #2
def main(argv=None):
    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    image = tf.placeholder(tf.float32,
                           shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3],
                           name="input_image")
    annotation = tf.placeholder(tf.int32,
                                shape=[None, IMAGE_SIZE, IMAGE_SIZE, 1],
                                name="annotation")

    pred_annotation_value, pred_annotation, logits = inference(
        image, keep_probability)
    tf.summary.image("input_image", image, max_outputs=2)
    tf.summary.image("ground_truth",
                     tf.cast(annotation, tf.uint8),
                     max_outputs=2)
    tf.summary.image("pred_annotation",
                     tf.cast(pred_annotation, tf.uint8),
                     max_outputs=2)
    # logits: the last layer of the conv net
    # labels: the ground truth
    loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=tf.squeeze(annotation, axis=[3]),
        name="entropy")))
    tf.summary.scalar("entropy", loss)

    trainable_var = tf.trainable_variables()
    if FLAGS.debug:
        for var in trainable_var:
            utils.add_to_regularization_and_summary(var)
    train_op = train(loss, trainable_var)

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()

    # Create the logs directory if it does not exist.
    if not os.path.exists(FLAGS.logs_dir):
        print("The logs path '%s' is not found" % FLAGS.logs_dir)
        print("Create now..")
        os.makedirs(FLAGS.logs_dir)
        print("%s is created successfully!" % FLAGS.logs_dir)

    #Create a file to write logs.
    #filename = 'logs' + FLAGS.mode + str(datetime.datetime.now()) + '.txt'
    filename = "logs_%s%s.txt" % (FLAGS.mode, datetime.datetime.now())
    path_ = os.path.join(FLAGS.logs_dir, filename)
    with open(path_, 'w') as logs_file:
        logs_file.write("The logs file is created at %s.\n" %
                        datetime.datetime.now())
        logs_file.write("The model is ---%s---.\n" % FLAGS.logs_dir)
        logs_file.write("The mode is %s\n" % (FLAGS.mode))
        logs_file.write(
            "The train data batch size is %d and the validation batch size is %d.\n"
            % (FLAGS.batch_size, FLAGS.v_batch_size))
        logs_file.write("The train data is %s.\n" % (FLAGS.data_dir))
        logs_file.write("The data size is %d and the MAX_ITERATION is %d.\n" %
                        (IMAGE_SIZE, MAX_ITERATION))
        logs_file.write("Setting up image reader...")

    print("Setting up image reader...")
    train_records, valid_records = scene_parsing.my_read_dataset(
        FLAGS.data_dir)
    print('number of train_records', len(train_records))
    print('number of valid_records', len(valid_records))
    with open(path_, 'a') as logs_file:
        logs_file.write('number of train_records %d\n' % len(train_records))
        logs_file.write('number of valid_records %d\n' % len(valid_records))

    print("Setting up dataset reader")
    image_options = {'resize': True, 'resize_size': IMAGE_SIZE}
    if FLAGS.mode == 'train':
        train_dataset_reader = dataset.BatchDatset(train_records,
                                                   image_options)
    validation_dataset_reader = dataset.BatchDatset(valid_records,
                                                    image_options)

    sess = tf.Session()

    print("Setting up Saver...")
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
    # If not training, restore the previously trained model.
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Model restored...")

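    # Note: this variant always trains, so it assumes FLAGS.mode == 'train';
    # train_dataset_reader is only created in that mode.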
    num_ = 0
    for itr in xrange(MAX_ITERATION):

        train_images, train_annotations = train_dataset_reader.next_batch(
            FLAGS.batch_size)
        feed_dict = {
            image: train_images,
            annotation: train_annotations,
            keep_probability: 0.85
        }

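        # Debug dump: print every positive annotation pixel of the first image
        # in the batch (assumes IMAGE_SIZE == 224).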
        for ii in range(224):
            for jj in range(224):
                if train_annotations[0][ii][jj] > 0:
                    print('anno', ii, ', ', jj, ': ',
                          train_annotations[0][ii][jj])

        sess.run(train_op, feed_dict=feed_dict)
        logs_file = open(path_, 'a')
        if itr % 10 == 0:
            train_loss, summary_str = sess.run([loss, summary_op],
                                               feed_dict=feed_dict)
            print("Step: %d, Train_loss:%g" % (itr, train_loss))
            summary_writer.add_summary(summary_str, itr)

        if itr % 500 == 0:

            # Calculate the accuracy on the training set.
            train_random_images, train_random_annotations = train_dataset_reader.get_random_batch_for_train(
                FLAGS.v_batch_size)
            train_loss, train_pred_anno = sess.run(
                [loss, pred_annotation],
                feed_dict={
                    image: train_random_images,
                    annotation: train_random_annotations,
                    keep_probability: 1.0
                })
            accu_iou_, accu_pixel_ = accu.caculate_accurary(
                train_pred_anno, train_random_annotations)
            print("%s ---> Training_loss: %g" %
                  (datetime.datetime.now(), train_loss))
            print("%s ---> Training_pixel_accuary: %g" %
                  (datetime.datetime.now(), accu_pixel_))
            print("%s ---> Training_iou_accuary: %g" %
                  (datetime.datetime.now(), accu_iou_))
            print("---------------------------")
            #Output the logs.
            num_ = num_ + 1
            logs_file.write("No.%d the itr number is %d.\n" % (num_, itr))
            logs_file.write("%s ---> Training_loss: %g.\n" %
                            (datetime.datetime.now(), train_loss))
            logs_file.write("%s ---> Training_pixel_accuary: %g.\n" %
                            (datetime.datetime.now(), accu_pixel_))
            logs_file.write("%s ---> Training_iou_accuary: %g.\n" %
                            (datetime.datetime.now(), accu_iou_))
            logs_file.write("---------------------------\n")

            valid_images, valid_annotations = validation_dataset_reader.next_batch(
                FLAGS.v_batch_size)
            #valid_loss = sess.run(loss, feed_dict={image: valid_images, annotation: valid_annotations,
            #keep_probability: 1.0})
            valid_loss, pred_anno = sess.run(
                [loss, pred_annotation],
                feed_dict={
                    image: valid_images,
                    annotation: valid_annotations,
                    keep_probability: 1.0
                })
            accu_iou, accu_pixel = accu.caculate_accurary(
                pred_anno, valid_annotations)
            print("%s ---> Validation_loss: %g" %
                  (datetime.datetime.now(), valid_loss))
            print("%s ---> Validation_pixel_accuary: %g" %
                  (datetime.datetime.now(), accu_pixel))
            print("%s ---> Validation_iou_accuary: %g" %
                  (datetime.datetime.now(), accu_iou))

            # Output the logs.
            logs_file.write("%s ---> Validation_loss: %g.\n" %
                            (datetime.datetime.now(), valid_loss))
            logs_file.write("%s ---> Validation_pixel_accuracy: %g.\n" %
                            (datetime.datetime.now(), accu_pixel))
            logs_file.write("%s ---> Validation_iou_accuracy: %g.\n" %
                            (datetime.datetime.now(), accu_iou))
            saver.save(sess, FLAGS.logs_dir + "model.ckpt", itr)
        '''
        if itr % 10000 == 0:
            count = 0
            if_con = True
            accu_iou_t = 0
            accu_pixel_t = 0

            while if_con:
                count = count + 1
                valid_images, valid_annotations, filenames, if_con, start, end = validation_dataset_reader.next_batch_valid(FLAGS.v_batch_size)
                valid_loss, pred_anno = sess.run([loss, pred_annotation],
                                                 feed_dict={image: valid_images,
                                                            annotation: valid_annotations,
                                                            keep_probability: 1.0})
                accu_iou, accu_pixel = accu.caculate_accurary(pred_anno, valid_annotations)

                accu_iou_t = accu_iou_t + accu_iou
                accu_pixel_t = accu_pixel_t + accu_pixel
            print("No.%d total validation data accuracy" % itr)
            print("%s ---> Total validation_pixel_accuracy: %g" % (datetime.datetime.now(), accu_pixel_t / count))
            print("%s ---> Total validation_iou_accuracy: %g" % (datetime.datetime.now(), accu_iou_t / count))
            # Write the logs file.
            logs_file.write("No.%d total validation data accuracy.\n" % itr)
            logs_file.write("%s ---> Total validation_pixel_accuracy: %g.\n" % (datetime.datetime.now(), accu_pixel_t / count))
            logs_file.write("%s ---> Total validation_iou_accuracy: %g.\n" % (datetime.datetime.now(), accu_iou_t / count))
        '''
        # End of this iteration.
        logs_file.close()
Example #3
def main(argv=None):
    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    image = tf.placeholder(tf.float32,
                           shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3],
                           name="input_image")
    annotation = tf.placeholder(tf.int32,
                                shape=[None, IMAGE_SIZE, IMAGE_SIZE, 1],
                                name="annotation")

    pred_annotation, logits = inference(image, keep_probability)
    tf.summary.image("input_image", image, max_outputs=2)
    tf.summary.image("ground_truth",
                     tf.cast(annotation, tf.uint8),
                     max_outputs=2)
    tf.summary.image("pred_annotation",
                     tf.cast(pred_annotation, tf.uint8),
                     max_outputs=2)
    # logits: the last layer of the conv net
    # labels: the ground truth
    loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=tf.squeeze(annotation, axis=[3]),
        name="entropy")))
    tf.summary.scalar("entropy", loss)

    trainable_var = tf.trainable_variables()
    if FLAGS.debug:
        for var in trainable_var:
            utils.add_to_regularization_and_summary(var)
    train_op = train(loss, trainable_var)

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()

    print("Setting up image reader...")
    train_records, valid_records = scene_parsing.my_read_dataset(
        FLAGS.data_dir)
    print('number of train_records', len(train_records))
    print('number of valid_records', len(valid_records))

    print("Setting up dataset reader")
    image_options = {'resize': True, 'resize_size': IMAGE_SIZE}
    if FLAGS.mode == 'train':
        train_dataset_reader = dataset.BatchDatset(train_records,
                                                   image_options)
    validation_dataset_reader = dataset.BatchDatset(valid_records,
                                                    image_options)

    sess = tf.Session()

    print("Setting up Saver...")
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
    # If not training, restore the previously trained model.
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Model restored...")

    if FLAGS.mode == "train":
        for itr in xrange(MAX_ITERATION):
            train_images, train_annotations = train_dataset_reader.next_batch(
                FLAGS.batch_size)
            feed_dict = {
                image: train_images,
                annotation: train_annotations,
                keep_probability: 0.85
            }
            '''pred_annotation_, logits_=sess.run([pred_annotation, logits],feed_dict=feed_dict)
            print('shape of pred_anno',pred_annotation_.shape)
            print('pred_anno',pred_annotation_)
            print('shape of conv3',logits_.shape)
            print('logit',logits_)
            print('shape of train_annotations',train_annotations.shape)
            print('train_annotations',train_annotations)'''

            sess.run(train_op, feed_dict=feed_dict)

            if itr % 10 == 0:
                train_loss, summary_str = sess.run([loss, summary_op],
                                                   feed_dict=feed_dict)
                print("Step: %d, Train_loss:%g" % (itr, train_loss))
                summary_writer.add_summary(summary_str, itr)

            if itr % 500 == 0:
                valid_images, valid_annotations = validation_dataset_reader.next_batch(
                    FLAGS.batch_size)
                valid_loss = sess.run(loss,
                                      feed_dict={
                                          image: valid_images,
                                          annotation: valid_annotations,
                                          keep_probability: 1.0
                                      })
                print("%s ---> Validation_loss: %g" %
                      (datetime.datetime.now(), valid_loss))
                saver.save(sess, FLAGS.logs_dir + "model.ckpt", itr)

    elif FLAGS.mode == "visualize":
        valid_images, valid_annotations = validation_dataset_reader.get_random_batch(
            FLAGS.batch_size)

        t_start = time.time()
        pred = sess.run(pred_annotation,
                        feed_dict={
                            image: valid_images,
                            annotation: valid_annotations,
                            keep_probability: 1.0
                        })
        '''t_elapsed = time.time() - t_start
        speed = len(pred)/t_elapsed
        print('the num of picture',len(pred))
        print('speed',speed)'''

        valid_annotations = np.squeeze(valid_annotations, axis=3)
        pred = np.squeeze(pred, axis=3)

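        # Throughput measured from just before the forward pass, in images per second.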
        t_elapsed_s = time.time() - t_start
        speed_s = len(pred) / t_elapsed_s
        print('the num of pictures', len(pred))
        print('speed_needle', speed_s)

        #fit_ellipse
        '''for itr in range(FLAGS.batch_size):
            valid_images_=fit_ellipse(valid_images[itr],pred[itr])
            utils.save_image(valid_images_.astype(np.uint8), FLAGS.logs_dir, name="inp_" + str(5+itr))
            utils.save_image(valid_annotations[itr].astype(np.uint8), FLAGS.logs_dir, name="gt_" + str(5+itr))
            utils.save_image(pred[itr].astype(np.uint8), FLAGS.logs_dir, name="pred_" + str(5+itr))
            print("Saved image: %d" % itr)'''

        #label_predict_pixel
        for itr in range(FLAGS.batch_size):
            valid_images_ = label_visualize(valid_images[itr], pred[itr])
            utils.save_image(valid_images_.astype(np.uint8),
                             FLAGS.logs_dir,
                             name="inp_" + str(5 + itr))
            utils.save_image(valid_annotations[itr].astype(np.uint8),
                             FLAGS.logs_dir,
                             name="gt_" + str(5 + itr))
            utils.save_image(pred[itr].astype(np.uint8),
                             FLAGS.logs_dir,
                             name="pred_" + str(5 + itr))
            print("Saved image: %d" % itr)

        #the origin fcn validation
        '''for itr in range(FLAGS.batch_size):
            utils.save_image(valid_images[itr].astype(np.uint8), FLAGS.logs_dir, name="inp_" + str(5+itr))
            utils.save_image(valid_annotations[itr].astype(np.uint8), FLAGS.logs_dir, name="gt_" + str(5+itr))
            utils.save_image(pred[itr].astype(np.uint8), FLAGS.logs_dir, name="pred_" + str(5+itr))
            print("Saved image: %d" % itr)'''

        t_elapsed_ = time.time() - t_start
        speed = len(pred) / t_elapsed_
        print('the num of pictures', len(pred))
        print('speed_needle', speed)
Example #4
def main(argv=None):
    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    learning_rate = tf.placeholder(tf.float32, name="learning_rate")
    image = tf.placeholder(tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3], name="input_image")
    annotation = tf.placeholder(tf.int32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 1], name="annotation")

    pred_annotation_value, pred_annotation, logits = inference(image, keep_probability)
    tf.summary.image("input_image", image, max_outputs=2)
    tf.summary.image("ground_truth", tf.cast(annotation, tf.uint8), max_outputs=2)
    tf.summary.image("pred_annotation", tf.cast(pred_annotation, tf.uint8), max_outputs=2)

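    # Alpha-balancing term of a focal-style loss: a_w equals at for background
    # pixels (label 0) and 1 - at for foreground (label 1).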
    at = float(FLAGS.at)
    a_w = (1 - 2 * at) * tf.cast(tf.squeeze(annotation, axis=[3]), tf.float32) + at
    #a_w = tf.reduce_sum(a_w_ * tf.one_hot(tf.squeeze(annotation, axis=[1]), FLAGS.NUM_OF_CLASSESS), 1)
    

    pro = tf.nn.softmax(logits)
     
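    # Focal modulation (1 - p_t) ** FLAGS.gamma, where p_t is the softmax
    # probability assigned to the ground-truth class.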
    loss_weight = tf.pow(
        1 - tf.reduce_sum(
            pro * tf.one_hot(tf.squeeze(annotation, axis=[3]), FLAGS.NUM_OF_CLASSESS), 3),
        FLAGS.gamma)
     
    
    loss = tf.reduce_mean(loss_weight * a_w * tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=tf.squeeze(annotation, axis=[3]),
        name="entropy"))

    valid_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=tf.squeeze(annotation, axis=[3]),
        name="valid_entropy"))
   
    tf.summary.scalar("train_loss", loss)
    tf.summary.scalar("valid_loss", valid_loss)
    tf.summary.scalar("learning_rate", learning_rate)

    trainable_var = tf.trainable_variables()
    if FLAGS.debug:
        for var in trainable_var:
            utils.add_to_regularization_and_summary(var)
    train_op = train(loss, trainable_var, learning_rate)

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()
    
    # Create the logs directory if it does not exist.
    if not os.path.exists(FLAGS.logs_dir):
        print("The logs path '%s' is not found" % FLAGS.logs_dir)
        print("Create now..")
        os.makedirs(FLAGS.logs_dir)
        print("%s is created successfully!" % FLAGS.logs_dir)

    #Create a file to write logs.
    filename="logs_%s%s.txt"%(FLAGS.mode,datetime.datetime.now())
    path_=os.path.join(FLAGS.logs_dir,filename)
    with open(path_,'w') as logs_file:
        logs_file.write("The logs file is created at %s.\n" % datetime.datetime.now())
        logs_file.write("The model is ---%s---.\n" % FLAGS.logs_dir)
        logs_file.write("The mode is %s\n"% (FLAGS.mode))
        logs_file.write("The train data batch size is %d and the validation batch size is %d\n."%(FLAGS.batch_size,FLAGS.v_batch_size))
        logs_file.write("The train data is %s.\n" % (FLAGS.data_dir))
        logs_file.write("The data size is %d and the MAX_ITERATION is %d.\n" % (IMAGE_SIZE, MAX_ITERATION))
        logs_file.write("Setting up image reader...")

    
    print("Setting up image reader...")
    train_records, valid_records = scene_parsing.my_read_dataset(FLAGS.data_dir)
    print('number of train_records', len(train_records))
    print('number of valid_records', len(valid_records))
    with open(path_, 'a') as logs_file:
        logs_file.write('number of train_records %d\n' % len(train_records))
        logs_file.write('number of valid_records %d\n' % len(valid_records))

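    # The learning rate lives in a plain-text file and is re-read every
    # iteration below, so it can be adjusted while training runs.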
    path_lr = FLAGS.learning_rate_path
    with open(path_lr, 'r') as f:
        lr_ = float(f.readline().split('\n')[0])


    print("Setting up dataset reader")
    image_options = {'resize': True, 'resize_size': IMAGE_SIZE}
    if FLAGS.mode == 'train':
        train_dataset_reader = dataset.BatchDatset(train_records, image_options)
    validation_dataset_reader = dataset.BatchDatset(valid_records, image_options)

    sess = tf.Session()

    print("Setting up Saver...")
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
    # If not training, restore the previously trained model.
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Model restored...")

    num_ = 0
    current_itr = FLAGS.train_itr
    for itr in xrange(current_itr, MAX_ITERATION):

        with open(path_lr, 'r') as f:
            lr_ = float(f.readline().split('\n')[0])

        train_images, train_annotations = train_dataset_reader.next_batch(FLAGS.batch_size)
        feed_dict = {image: train_images, annotation: train_annotations, keep_probability: 0.85, learning_rate: lr_}
        

        sess.run(train_op, feed_dict=feed_dict)
        logs_file = open(path_, 'a')  
        if itr % 10 == 0:
            train_loss, summary_str = sess.run([loss, summary_op], feed_dict=feed_dict)
            print("Step: %d, Train_loss:%g" % (itr, train_loss))
            summary_writer.add_summary(summary_str, itr)

        if itr % 500 == 0:
            
            # Calculate the accuracy on the training set.
            train_random_images, train_random_annotations = train_dataset_reader.get_random_batch_for_train(FLAGS.v_batch_size)
            train_loss, train_pred_anno = sess.run([loss, pred_annotation],
                                                   feed_dict={image: train_random_images,
                                                              annotation: train_random_annotations,
                                                              keep_probability: 1.0})
            accu_iou_, accu_pixel_ = accu.caculate_accurary(train_pred_anno, train_random_annotations)
            print("%s ---> Training_loss: %g" % (datetime.datetime.now(), train_loss))
            print("%s ---> Training_pixel_accuracy: %g" % (datetime.datetime.now(), accu_pixel_))
            print("%s ---> Training_iou_accuracy: %g" % (datetime.datetime.now(), accu_iou_))
            print("---------------------------")
            # Output the logs.
            num_ = num_ + 1
            logs_file.write("No.%d the itr number is %d.\n" % (num_, itr))
            logs_file.write("%s ---> Training_loss: %g.\n" % (datetime.datetime.now(), train_loss))
            logs_file.write("%s ---> Training_pixel_accuracy: %g.\n" % (datetime.datetime.now(), accu_pixel_))
            logs_file.write("%s ---> Training_iou_accuracy: %g.\n" % (datetime.datetime.now(), accu_iou_))
            logs_file.write("---------------------------\n")

            valid_images, valid_annotations = validation_dataset_reader.next_batch(FLAGS.v_batch_size)
            #valid_loss = sess.run(loss, feed_dict={image: valid_images, annotation: valid_annotations, keep_probability: 1.0})
            valid_loss_, pred_anno = sess.run([valid_loss, pred_annotation],
                                              feed_dict={image: valid_images,
                                                         annotation: valid_annotations,
                                                         keep_probability: 1.0})
            accu_iou, accu_pixel = accu.caculate_accurary(pred_anno, valid_annotations)
            print("%s ---> Validation_loss: %g" % (datetime.datetime.now(), valid_loss_))
            print("%s ---> Validation_pixel_accuracy: %g" % (datetime.datetime.now(), accu_pixel))
            print("%s ---> Validation_iou_accuracy: %g" % (datetime.datetime.now(), accu_iou))

            # Output the logs.
            logs_file.write("%s ---> Validation_loss: %g.\n" % (datetime.datetime.now(), valid_loss_))
            logs_file.write("%s ---> Validation_pixel_accuracy: %g.\n" % (datetime.datetime.now(), accu_pixel))
            logs_file.write("%s ---> Validation_iou_accuracy: %g.\n" % (datetime.datetime.now(), accu_iou))
            saver.save(sess, FLAGS.logs_dir + "model.ckpt", itr)
            # End of this iteration.
        logs_file.close()
Example #5
def main(argv=None):
    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    image = tf.placeholder(tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3], name="input_image")
    soft_annotation = tf.placeholder(tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 2], name="soft_annotation")
    hard_annotation = tf.placeholder(tf.int32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 1], name="hard_annotation")

    pred_annotation_value, pred_annotation, logits = inference(image, keep_probability)
    #tf.summary.image("input_image", image, max_outputs=2)
    #tf.summary.image("ground_truth", tf.cast(soft_annotation, tf.uint8), max_outputs=2)
    #tf.summary.image("pred_annotation", tf.cast(pred_annotation, tf.uint8), max_outputs=2)
    # logits: the last layer of the conv net
    # labels: the ground truth
    #loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
    #                                                                      labels=tf.squeeze(annotation, squeeze_dims=[3]),
    #                                                                      name="entropy")))
    # NOTE: this update is not finished yet.
    '''
    soft_logits = tf.nn.softmax(logits/FLAGS.temperature)
    soft_loss = tf.reduce_mean((tf.nn.sigmoid_cross_entropy_with_logits(logits=soft_logits,
                                                                 #labels = tf.squeeze(soft_annotation, squeeze_dims=[3]),
                                                                  labels = soft_annotation,
                                                                  name = "entropy_soft")))
    #use weight
    soft_loss0 = 0.8*tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=soft_logits[:,:,:,0],
                                                         labels=soft_annotation[:,:,:,0],
                                                         name ="entropy_soft0"))
    soft_loss1 = 0.2*tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=soft_logits[:,:,:,1],
                                                         labels=soft_annotation[:,:,:,1],
                                                         name ="entropy_soft1"))
    soft_loss = tf.add(soft_loss0, soft_loss1)
    '''
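    # Knowledge-distillation-style objective: soften the logits with a
    # temperature and match them against the soft targets.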
    soft_logits = tf.nn.softmax(logits / FLAGS.temperature)
    soft_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
        logits=logits / FLAGS.temperature,
        labels=soft_annotation,
        name="entropy_soft"))
    hard_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=tf.squeeze(hard_annotation, axis=[3]),
        name="entropy_hard"))

    tf.summary.scalar("entropy", soft_loss)

    trainable_var = tf.trainable_variables()
    if FLAGS.debug:
        for var in trainable_var:
            utils.add_to_regularization_and_summary(var)
    train_op = train(soft_loss, trainable_var)

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()
    
    # Create the logs directory if it does not exist.
    if not os.path.exists(FLAGS.logs_dir):
        print("The logs path '%s' is not found" % FLAGS.logs_dir)
        print("Create now..")
        os.makedirs(FLAGS.logs_dir)
        print("%s is created successfully!" % FLAGS.logs_dir)

    #Create a file to write logs.
    #filename = 'logs' + FLAGS.mode + str(datetime.datetime.now()) + '.txt'
    filename = "logs_%s%s.txt" % (FLAGS.mode, datetime.datetime.now())
    path_ = os.path.join(FLAGS.logs_dir, filename)
    with open(path_, 'w') as logs_file:
        logs_file.write("The logs file is created at %s.\n" % datetime.datetime.now())
        logs_file.write("The model is ---%s---.\n" % FLAGS.logs_dir)
        logs_file.write("The mode is %s\n" % (FLAGS.mode))
        logs_file.write("The train data batch size is %d and the validation batch size is %d.\n" % (FLAGS.batch_size, FLAGS.v_batch_size))
        logs_file.write("The train data is %s.\n" % (FLAGS.data_dir))
        logs_file.write("The image size is %d and the MAX_ITERATION is %d.\n" % (IMAGE_SIZE, MAX_ITERATION))
        logs_file.write("Setting up image reader...\n")

    
    print("Setting up image reader...")
    train_records, valid_records = scene_parsing.my_read_dataset(FLAGS.data_dir)
    print('number of train_records', len(train_records))
    print('number of valid_records', len(valid_records))
    with open(path_, 'a') as logs_file:
        logs_file.write('number of train_records %d\n' % len(train_records))
        logs_file.write('number of valid_records %d\n' % len(valid_records))

    print("Setting up dataset reader")
    image_options = {'resize': True, 'resize_size': IMAGE_SIZE}
    if FLAGS.mode == 'train':
        train_dataset_reader = dataset_soft.BatchDatset(train_records, image_options)
    validation_dataset_reader = dataset.BatchDatset(valid_records, image_options)

    sess = tf.Session()

    print("Setting up Saver...")
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
    # If not training, restore the previously trained model.
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Model restored...")

    num_ = 0
    for itr in xrange(MAX_ITERATION):
        
        train_images, train_annotations = train_dataset_reader.next_batch(FLAGS.batch_size)
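        # FLAGS.normal presumably rescales the raw annotation values into
        # [0, 1] soft targets for the temperature-softened loss.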
        train_annotations_n = train_annotations / FLAGS.normal
        feed_dict = {image: train_images, soft_annotation: train_annotations_n, keep_probability: 0.85}

        sess.run(train_op, feed_dict=feed_dict)
        logs_file = open(path_, 'a')  
        if itr % 10 == 0:
            '''
            num_ = 0
            for ii in range(224):
                for jj in range(224):
                    #if train_annotations_n[0][ii][jj][1] > train_annotations_n[0][ii][jj][0]:
                    if train_annotations_n[0][ii][jj][1] > 0.85:
                        #print('anno',ii,', ', jj,': ', train_annotations_n[0][ii][jj])
                        num_ = num_+1
            print("the number of dim1>0.85: %d" % num_)
            '''
            soft_train_logits, train_logits = sess.run([soft_logits, logits], feed_dict=feed_dict)
            train_logits = np.array(train_logits)
            soft_train_logits = np.array(soft_train_logits)
            '''
            num_ = 0
            for ii in range(224):
                for jj in range(224):
                    if train_annotations_n[0][ii][jj][1] > 0.85:
                    #if soft_train_logits[0][ii][jj][1] > soft_train_logits[0][ii][jj][0]:
                    #if soft_train_logits[0][ii][jj][1] > 0.82:
                        #print('logtis',ii,', ', jj,': ', train_logits[0][ii][jj])
                        print('soft_logtis',ii,', ', jj,': ', soft_train_logits[0][ii][jj])
                        print('anno       ',ii,', ', jj,': ', train_annotations_n[0][ii][jj])
                        print('--------------------')
                    if soft_train_logits[0][ii][jj][1] > 0.82:
                        num_ = num_+1
            print("the number of dim1>0.82: %d" % num_)
            '''

            train_loss, summary_str = sess.run([soft_loss, summary_op], feed_dict=feed_dict)
            print("Step: %d, Train_loss:%g" % (itr, train_loss))
            logs_file.write("Step: %d, Train_loss:%g\n" % (itr, train_loss))
            summary_writer.add_summary(summary_str, itr)

        if itr % 500 == 0:
            
            # Calculate the accuracy on the training set.
            train_random_images, train_random_annotations = train_dataset_reader.get_random_batch_for_train(FLAGS.v_batch_size)
            train_logits, train_loss, train_pred_anno = sess.run(
                [soft_logits, soft_loss, pred_annotation],
                feed_dict={image: train_random_images,
                           soft_annotation: train_random_annotations / FLAGS.normal,
                           keep_probability: 1.0})
            #accu_iou_, accu_pixel_ = accu.caculate_accurary(train_pred_anno, train_random_annotations/100)
            print("%s ---> Training_loss: %g" % (datetime.datetime.now(), train_loss))
            #print("%s ---> Training_pixel_accuracy: %g" % (datetime.datetime.now(), accu_pixel_))
            #print("%s ---> Training_iou_accuracy: %g" % (datetime.datetime.now(), accu_iou_))
            print("---------------------------")
            # Output the logs.
            num_ = num_ + 1
            logs_file.write("No.%d the itr number is %d.\n" % (num_, itr))
            logs_file.write("%s ---> Training_loss: %g.\n" % (datetime.datetime.now(), train_loss))
            #logs_file.write("%s ---> Training_pixel_accuracy: %g.\n" % (datetime.datetime.now(), accu_pixel_))
            #logs_file.write("%s ---> Training_iou_accuracy: %g.\n" % (datetime.datetime.now(), accu_iou_))
            logs_file.write("---------------------------\n")

            valid_images, valid_annotations = validation_dataset_reader.next_batch(FLAGS.v_batch_size)
            valid_loss, pred_anno = sess.run([hard_loss, pred_annotation],
                                             feed_dict={image: valid_images,
                                                        hard_annotation: valid_annotations,
                                                        keep_probability: 1.0})
            accu_iou, accu_pixel = accu.caculate_accurary(pred_anno, valid_annotations)
            print("%s ---> Validation_loss: %g" % (datetime.datetime.now(), valid_loss))
            print("%s ---> Validation_pixel_accuracy: %g" % (datetime.datetime.now(), accu_pixel))
            print("%s ---> Validation_iou_accuracy: %g" % (datetime.datetime.now(), accu_iou))

            # Output the logs.
            logs_file.write("%s ---> Validation_loss: %g.\n" % (datetime.datetime.now(), valid_loss))
            logs_file.write("%s ---> Validation_pixel_accuracy: %g.\n" % (datetime.datetime.now(), accu_pixel))
            logs_file.write("%s ---> Validation_iou_accuracy: %g.\n" % (datetime.datetime.now(), accu_iou))
            saver.save(sess, FLAGS.logs_dir + "model.ckpt", itr)
            # End of this iteration.
        logs_file.close()