Example #1
    def __init__(self, mode, max_epochs, batch_size, n_classes, train_records,
                 valid_records, im_sz, init_lr, keep_prob, logs_dir):

        FCNNet.__init__(self, mode, max_epochs, batch_size, n_classes,
                        train_records, valid_records, im_sz, init_lr,
                        keep_prob, logs_dir)

        self.cur_batch_size = tf.placeholder(dtype=tf.int32,
                                             name='cur_batch_size')
        #mask
        self.seq_num = cfgs.seq_num
        self.cur_channel = cfgs.cur_channel
        self.channel = self.cur_channel + self.seq_num
        self.inference_name = 'inference'
        self.images = tf.placeholder(tf.float32,
                                     shape=[
                                         None, self.IMAGE_SIZE[0],
                                         self.IMAGE_SIZE[1],
                                         cfgs.seq_num + self.cur_channel
                                     ],
                                     name='input_image')
        self.create_view_path()

        accu.create_ellipse_f()
        self.e_acc = accu.Ellip_acc()
Example #2
 def calculate_acc(self,
                   im,
                   filenames,
                   pred_anno,
                   anno,
                   gt_ellip_info,
                   if_valid=False):
     with tf.name_scope('ellip_accu'):
         self.accu_iou, self.accu = accu.caculate_accurary(pred_anno, anno)
         self.ellip_acc = accu.caculate_ellip_accu(im, filenames, pred_anno,
                                                   gt_ellip_info, if_valid)
Example #3
    def __init__(self, mode, max_epochs, batch_size, n_classes, train_records, valid_records, im_sz, init_lr, keep_prob, logs_dir):

        FCNNet.__init__(self, mode, max_epochs, batch_size, n_classes, train_records, valid_records, im_sz, init_lr, keep_prob, logs_dir)

        #mask
        self.seq_num = cfgs.seq_num
        self.cur_channel = cfgs.cur_channel
        self.channel = 3+self.seq_num
        self.inference_name = 'soft_infer'
        self.images = tf.placeholder(tf.float32, shape=[None, self.IMAGE_SIZE, self.IMAGE_SIZE, cfgs.seq_num+self.cur_channel], name='input_image')
        accu.create_ellipse_f()
Example #4
    def __init__(self, mode, max_epochs, batch_size, n_classes, train_records,
                 valid_records, im_sz, init_lr, keep_prob, logs_dir):

        FCNNet.__init__(self, mode, max_epochs, batch_size, n_classes,
                        train_records, valid_records, im_sz, init_lr,
                        keep_prob, logs_dir)

        #mask
        self.seq_num = cfgs.seq_num
        self.cur_channel = cfgs.cur_channel
        self.channel = self.cur_channel + self.seq_num
        self.inference_name = 'inference'
        self.images = tf.placeholder(tf.float32,
                                     shape=[
                                         None, self.IMAGE_SIZE[0],
                                         self.IMAGE_SIZE[1],
                                         cfgs.seq_num + self.cur_channel
                                     ],
                                     name='input_image')
        self.grad_ims = tf.placeholder(
            tf.float32,
            shape=[None, self.IMAGE_SIZE[0], self.IMAGE_SIZE[1], 1],
            name='grad_image')

        self.create_view_path()
        self.coord_map_x, self.coord_map_y = self.generate_coord_map(
            self.batch_size)
        self.coord_x_tensor = tf.placeholder(
            tf.float32,
            shape=[None, self.IMAGE_SIZE[0], self.IMAGE_SIZE[1]],
            name='coord_x_map_tensor')
        self.coord_y_tensor = tf.placeholder(
            tf.float32,
            shape=[None, self.IMAGE_SIZE[0], self.IMAGE_SIZE[1]],
            name='coord_y_map_tensor')

        self.ellip_low = tf.placeholder(tf.float32,
                                        shape=[None],
                                        name='ellipse_info_lower_axis')
        self.ellip_high = tf.placeholder(tf.float32,
                                         shape=[None],
                                         name='ellipse_info_higher_axis')
        self.ellip_axis = tf.placeholder(tf.float32,
                                         shape=[None],
                                         name='ellipse_info_mean_axis')

        accu.create_ellipse_f()
        self.e_acc = accu.Ellip_acc()
Example #5
    def __init__(self, mode, max_epochs, batch_size, n_classes, train_records,
                 valid_records, im_sz, init_lr, keep_prob, logs_dir):

        FCNNet.__init__(self, mode, max_epochs, batch_size, n_classes,
                        train_records, valid_records, im_sz, init_lr,
                        keep_prob, logs_dir)

        #seq_mask(short for sm)
        self.seq_num = cfgs.seq_num
        self.cur_channel = cfgs.cur_channel
        self.sm_channel = self.cur_channel + self.seq_num
        self.sm_infer_name = 'inference'
        self.sm_images = tf.placeholder(tf.float32,
                                        shape=[
                                            None, self.IMAGE_SIZE,
                                            self.IMAGE_SIZE,
                                            cfgs.seq_num + self.cur_channel
                                        ],
                                        name='seq_mask_input_image')
        self.sm_annos = tf.placeholder(tf.int32,
                                       shape=[
                                           None, self.IMAGE_SIZE,
                                           self.IMAGE_SIZE,
                                           cfgs.seq_mask_anno_channel
                                       ],
                                       name='seq_mask_annos')
        #soft
        self.soft_infer_name = 'soft_infer'
        self.soft_channel = 3
        self.soft_anno_channel = cfgs.soft_anno_channel
        self.soft_images = tf.placeholder(
            tf.float32,
            shape=[None, self.IMAGE_SIZE, self.IMAGE_SIZE, self.soft_channel],
            name='soft_input_images')
        self.soft_annos = tf.placeholder(tf.float32,
                                         shape=[
                                             None, self.IMAGE_SIZE,
                                             self.IMAGE_SIZE,
                                             self.soft_anno_channel
                                         ],
                                         name='soft_annos')

        accu.create_ellipse_f()
        self.create_view_f()
        self.e_acc = accu.Ellip_acc()
Example #6
    def calculate_acc(self,
                      im,
                      filenames,
                      pred_anno,
                      pred_pro,
                      anno,
                      gt_ellip_info,
                      if_valid=False):
        with tf.name_scope('ellip_accu'):
            if cfgs.test_accu:
                self.accu_iou, self.accu = accu.caculate_accurary(
                    pred_anno, anno)

                #ellipse loss
                #self.ellip_acc = accu.caculate_ellip_accu(im, filenames, pred_anno, pred_pro, gt_ellip_info, if_valid)
                #Hausdorff loss
                self.ellip_acc = self.e_acc.caculate_ellip_accu(
                    im, filenames, pred_anno, pred_pro, gt_ellip_info,
                    if_valid)

            else:
                self.accu_iou = 0
                self.accu = 0
                self.ellip_acc = 0
Example #7
def main(argv=None):
    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    image = tf.placeholder(tf.float32,
                           shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3],
                           name="input_image")
    annotation = tf.placeholder(tf.int32,
                                shape=[None, IMAGE_SIZE, IMAGE_SIZE, 1],
                                name="annotation")
    soft_annotation = tf.placeholder(tf.float32,
                                     shape=[None, IMAGE_SIZE, IMAGE_SIZE, 2],
                                     name="soft_annotation")
    pred_annotation_value, pred_annotation, logits, pred_prob = inference(
        image, keep_probability)
    tf.summary.image("input_image", image, max_outputs=2)
    tf.summary.image("ground_truth",
                     tf.cast(annotation, tf.uint8),
                     max_outputs=2)
    tf.summary.image("pred_annotation",
                     tf.cast(pred_annotation, tf.uint8),
                     max_outputs=2)
    #logits:the last layer of conv net
    #labels:the ground truth
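    # tf.squeeze drops the singleton channel axis so the labels have shape [batch, height, width],
    # as sparse_softmax_cross_entropy_with_logits expects.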
    loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=tf.squeeze(annotation, squeeze_dims=[3]),
        name="entropy")))
    tf.summary.scalar("entropy", loss)

    trainable_var = tf.trainable_variables()
    if FLAGS.debug:
        for var in trainable_var:
            utils.add_to_regularization_and_summary(var)
    train_op = train(loss, trainable_var)

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()

    #Create a file to write logs.
    #filename='logs'+ FLAGS.mode + str(datetime.datetime.now()) + '.txt'
    filename = "logs_%s%s.txt" % (FLAGS.mode, datetime.datetime.now())
    path_ = os.path.join(FLAGS.logs_dir, filename)
    logs_file = open(path_, 'w')
    logs_file.write("The logs file is created at %s\n" %
                    datetime.datetime.now())
    logs_file.write("The mode is %s\n" % (FLAGS.mode))
    logs_file.write(
        "The train data batch size is %d and the validation batch size is %d.\n"
        % (FLAGS.batch_size, FLAGS.v_batch_size))
    logs_file.write("The train data is %s.\n" % (FLAGS.data_dir))
    logs_file.write("The data size is %d and the MAX_ITERATION is %d.\n" %
                    (IMAGE_SIZE, MAX_ITERATION))
    logs_file.write("The model is ---%s---.\n" % FLAGS.logs_dir)

    print("Setting up image reader...")
    logs_file.write("Setting up image reader...\n")
    train_records, valid_records = scene_parsing.my_read_dataset(
        FLAGS.data_dir)
    print('number of train_records', len(train_records))
    print('number of valid_records', len(valid_records))
    logs_file.write('number of train_records %d\n' % len(train_records))
    logs_file.write('number of valid_records %d\n' % len(valid_records))

    print("Setting up dataset reader")
    image_options = {'resize': True, 'resize_size': IMAGE_SIZE}
    if FLAGS.mode == 'check_training':
        train_dataset_reader = dataset_soft.BatchDatset(
            train_records, image_options)
    validation_dataset_reader = dataset.BatchDatset(valid_records,
                                                    image_options)

    sess = tf.Session()

    print("Setting up Saver...")
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
    #If not training, restore the previously trained model.
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Model restored...")

    if FLAGS.mode == "accurary":
        count = 0
        if_con = True
        accu_iou_t = 0
        accu_pixel_t = 0

        while if_con:
            count = count + 1
            valid_images, valid_annotations, valid_filenames, if_con, start, end = validation_dataset_reader.next_batch_valid(
                FLAGS.v_batch_size)
            valid_loss, pred_anno = sess.run(
                [loss, pred_annotation],
                feed_dict={
                    image: valid_images,
                    annotation: valid_annotations,
                    keep_probability: 1.0
                })
            accu_iou, accu_pixel = accu.caculate_accurary(
                pred_anno, valid_annotations)
            print("Ture %d ---> the data from %d to %d" % (count, start, end))
            print("%s ---> Validation_pixel_accuary: %g" %
                  (datetime.datetime.now(), accu_pixel))
            print("%s ---> Validation_iou_accuary: %g" %
                  (datetime.datetime.now(), accu_iou))
            #Output logs.
            logs_file.write("Ture %d ---> the data from %d to %d\n" %
                            (count, start, end))
            logs_file.write("%s ---> Validation_pixel_accuary: %g\n" %
                            (datetime.datetime.now(), accu_pixel))
            logs_file.write("%s ---> Validation_iou_accuary: %g\n" %
                            (datetime.datetime.now(), accu_iou))

            accu_iou_t = accu_iou_t + accu_iou
            accu_pixel_t = accu_pixel_t + accu_pixel
        print("%s ---> Total validation_pixel_accuary: %g" %
              (datetime.datetime.now(), accu_pixel_t / count))
        print("%s ---> Total validation_iou_accuary: %g" %
              (datetime.datetime.now(), accu_iou_t / count))
        #Output logs
        logs_file.write("%s ---> Total validation_pixel_accurary: %g\n" %
                        (datetime.datetime.now(), accu_pixel_t / count))
        logs_file.write("%s ---> Total validation_iou_accurary: %g\n" %
                        (datetime.datetime.now(), accu_iou_t / count))

    elif FLAGS.mode == "all_visualize":

        re_save_dir = "%s%s" % (FLAGS.result_dir, datetime.datetime.now())
        logs_file.write("The result is save at file'%s'.\n" % re_save_dir)
        logs_file.write("The number of part visualization is %d.\n" %
                        FLAGS.v_batch_size)

        #Check whether the result directory exists.
        if not os.path.exists(re_save_dir):
            print("The path '%s' is not found." % re_save_dir)
            print("Create now ...")
            os.makedirs(re_save_dir)
            print("Create '%s' successfully." % re_save_dir)
            logs_file.write("Create '%s' successfully.\n" % re_save_dir)
        re_save_dir_anno = os.path.join(re_save_dir, 'anno')
        re_save_dir_pred = os.path.join(re_save_dir, 'pred')
        re_save_dir_train_heat = os.path.join(re_save_dir, 'train_heat')
        re_save_dir_heat = os.path.join(re_save_dir, 'heatmap')
        re_save_dir_ellip = os.path.join(re_save_dir, 'ellip')
        re_save_dir_transheat = os.path.join(re_save_dir, 'transheat')
        if not os.path.exists(re_save_dir_anno):
            os.makedirs(re_save_dir_anno)
        if not os.path.exists(re_save_dir_pred):
            os.makedirs(re_save_dir_pred)
        if not os.path.exists(re_save_dir_train_heat):
            os.makedirs(re_save_dir_train_heat)
        if not os.path.exists(re_save_dir_heat):
            os.makedirs(re_save_dir_heat)
        if not os.path.exists(re_save_dir_ellip):
            os.makedirs(re_save_dir_ellip)
        if not os.path.exists(re_save_dir_transheat):
            os.makedirs(re_save_dir_transheat)
        count = 0
        if_con = True
        accu_iou_t = 0
        accu_pixel_t = 0

        while if_con:
            count = count + 1
            valid_images, valid_annotations, valid_filename, if_con, start, end = validation_dataset_reader.next_batch_valid(
                FLAGS.v_batch_size)
            pred_value, pred, logits_, pred_prob_ = sess.run(
                [pred_annotation_value, pred_annotation, logits, pred_prob],
                feed_dict={
                    image: valid_images,
                    annotation: valid_annotations,
                    keep_probability: 1.0
                })
            #valid_annotations = np.squeeze(valid_annotations, axis=3)
            pred = np.squeeze(pred, axis=3)
            pred_value = np.squeeze(pred_value, axis=3)
            pred_ellip = np.argmax(pred_prob_ + [0.85, 0], axis=3)
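            # Adding [0.85, 0] raises the background probability before argmax, so a pixel is kept as
            # foreground only when its foreground probability clearly dominates (an assumed confidence
            # threshold).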
            #label_predict_pixel
            for itr in range(len(pred)):
                filename = valid_filename[itr]['filename']
                if FLAGS.anno == 'T':
                    valid_images_anno = anno_visualize(
                        valid_images[itr].copy(), valid_annotations[itr, :, :,
                                                                    1])
                    utils.save_image(valid_images_anno.astype(np.uint8),
                                     re_save_dir_anno,
                                     name="anno_" + filename)
                if FLAGS.pred == 'T':
                    valid_images_pred = pred_visualize(
                        valid_images[itr].copy(),
                        np.expand_dims(pred_ellip[itr], axis=2))
                    utils.save_image(valid_images_pred.astype(np.uint8),
                                     re_save_dir_pred,
                                     name="pred_" + filename)
                if FLAGS.train_heat == 'T':
                    heat_map = density_heatmap(
                        valid_annotations[itr, :, :, 1] / FLAGS.normal)
                    utils.save_image(heat_map.astype(np.uint8),
                                     re_save_dir_train_heat,
                                     name="trainheat_" + filename)

                if FLAGS.fit_ellip == 'T':
                    valid_images_ellip = fit_ellipse_findContours(
                        valid_images[itr].copy(),
                        np.expand_dims(pred_ellip[itr],
                                       axis=2).astype(np.uint8))
                    utils.save_image(valid_images_ellip.astype(np.uint8),
                                     re_save_dir_ellip,
                                     name="ellip_" + filename)
                if FLAGS.heatmap == 'T':
                    heat_map = density_heatmap(pred_prob_[itr, :, :, 1])
                    utils.save_image(heat_map.astype(np.uint8),
                                     re_save_dir_heat,
                                     name="heat_" + filename)
                if FLAGS.trans_heat == 'T':
                    trans_heat_map = translucent_heatmap(
                        valid_images[itr],
                        heat_map.astype(np.uint8).copy())
                    utils.save_image(trans_heat_map,
                                     re_save_dir_transheat,
                                     name="trans_heat_" + filename)

    elif FLAGS.mode == 'check_training':
        re_save_dir = "%s%s" % (FLAGS.result_dir, datetime.datetime.now())
        logs_file.write("The result is save at file'%s'.\n" % re_save_dir)
        logs_file.write("The number of part visualization is %d.\n" %
                        FLAGS.v_batch_size)

        #Check whether the result directory exists.
        if not os.path.exists(re_save_dir):
            print("The path '%s' is not found." % re_save_dir)
            print("Create now ...")
            os.makedirs(re_save_dir)
            print("Create '%s' successfully." % re_save_dir)
            logs_file.write("Create '%s' successfully.\n" % re_save_dir)
        re_save_dir_anno = os.path.join(re_save_dir, 'anno')
        re_save_dir_pred = os.path.join(re_save_dir, 'pred')
        re_save_dir_train_heat = os.path.join(re_save_dir, 'train_heat')
        re_save_dir_heat = os.path.join(re_save_dir, 'heatmap')
        re_save_dir_ellip = os.path.join(re_save_dir, 'ellip')
        re_save_dir_transheat = os.path.join(re_save_dir, 'transheat')
        if not os.path.exists(re_save_dir_anno):
            os.makedirs(re_save_dir_anno)
        if not os.path.exists(re_save_dir_pred):
            os.makedirs(re_save_dir_pred)
        if not os.path.exists(re_save_dir_train_heat):
            os.makedirs(re_save_dir_train_heat)
        if not os.path.exists(re_save_dir_heat):
            os.makedirs(re_save_dir_heat)
        if not os.path.exists(re_save_dir_ellip):
            os.makedirs(re_save_dir_ellip)
        if not os.path.exists(re_save_dir_transheat):
            os.makedirs(re_save_dir_transheat)
        count = 0
        if_con = True
        accu_iou_t = 0
        accu_pixel_t = 0

        while if_con:
            count = count + 1
            valid_images, valid_annotations, valid_filename, if_con, start, end = train_dataset_reader.next_batch_valid(
                FLAGS.v_batch_size)
            pred_value, pred, logits_, pred_prob_ = sess.run(
                [pred_annotation_value, pred_annotation, logits, pred_prob],
                feed_dict={
                    image: valid_images,
                    soft_annotation: valid_annotations,
                    keep_probability: 1.0
                })
            #valid_annotations = np.squeeze(valid_annotations, axis=3)
            pred = np.squeeze(pred, axis=3)
            pred_value = np.squeeze(pred_value, axis=3)

            #label_predict_pixel
            for itr in range(len(pred)):
                filename = valid_filename[itr]['filename']
                if FLAGS.anno == 'T':
                    valid_images_anno = anno_visualize(
                        valid_images[itr].copy(), valid_annotations[itr, :, :,
                                                                    1])
                    utils.save_image(valid_images_anno.astype(np.uint8),
                                     re_save_dir_anno,
                                     name="anno_" + filename)
                if FLAGS.pred == 'T':
                    valid_images_pred = soft_pred_visualize(
                        valid_images[itr].copy(), pred[itr])
                    utils.save_image(valid_images_pred.astype(np.uint8),
                                     re_save_dir_pred,
                                     name="pred_" + filename)
                if FLAGS.train_heat == 'T':
                    heat_map = density_heatmap(
                        valid_annotations[itr, :, :, 1] / FLAGS.normal)
                    utils.save_image(heat_map.astype(np.uint8),
                                     re_save_dir_train_heat,
                                     name="trainheat_" + filename)

                if FLAGS.fit_ellip == 'T':
                    valid_images_ellip = fit_ellipse_findContours(
                        valid_images[itr].copy(),
                        np.expand_dims(pred[itr], axis=2).astype(np.uint8))
                    utils.save_image(valid_images_ellip.astype(np.uint8),
                                     re_save_dir_ellip,
                                     name="ellip_" + filename)
                if FLAGS.heatmap == 'T':
                    heat_map = density_heatmap(pred_prob_[itr, :, :, 1])
                    utils.save_image(heat_map.astype(np.uint8),
                                     re_save_dir_heat,
                                     name="heat_" + filename)
                if FLAGS.trans_heat == 'T':
                    trans_heat_map = translucent_heatmap(
                        valid_images[itr],
                        heat_map.astype(np.uint8).copy())
                    utils.save_image(trans_heat_map,
                                     re_save_dir_transheat,
                                     name="trans_heat_" + filename)

    logs_file.close()
    if FLAGS.mode == "check_training" or FLAGS.mode == "all_visualize":
        result_logs_file = os.path.join(re_save_dir, filename)
        shutil.copyfile(path_, result_logs_file)
Example #8
 def calculate_acc(self, pred_anno, anno):
     with tf.name_scope('accu'):
         self.accu_iou, self.accu = accu.caculate_accurary(pred_anno, anno)
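A minimal numpy sketch of what accu.caculate_accurary presumably computes (foreground IoU followed by pixel accuracy); the function below is a hypothetical stand-in, not the project's actual helper:

import numpy as np

def iou_and_pixel_accuracy(pred_anno, anno):
    # Hypothetical illustration: treat any non-zero label as foreground.
    pred = np.asarray(pred_anno) > 0
    gt = np.asarray(anno) > 0
    pixel_acc = float(np.mean(pred == gt))
    union = np.logical_or(pred, gt).sum()
    iou = float(np.logical_and(pred, gt).sum()) / union if union > 0 else 0.0
    return iou, pixel_acc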
Example #9
def main(argv=None):
    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    image = tf.placeholder(tf.float32,
                           shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3],
                           name="input_image")
    annotation = tf.placeholder(tf.int32,
                                shape=[None, IMAGE_SIZE, IMAGE_SIZE, 1],
                                name="annotation")

    pred_annotation_value, pred_annotation, logits = inference(
        image, keep_probability)
    tf.summary.image("input_image", image, max_outputs=2)
    tf.summary.image("ground_truth",
                     tf.cast(annotation, tf.uint8),
                     max_outputs=2)
    tf.summary.image("pred_annotation",
                     tf.cast(pred_annotation, tf.uint8),
                     max_outputs=2)
    #logits:the last layer of conv net
    #labels:the ground truth
    loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=tf.squeeze(annotation, squeeze_dims=[3]),
        name="entropy")))
    tf.summary.scalar("entropy", loss)

    trainable_var = tf.trainable_variables()
    if FLAGS.debug:
        for var in trainable_var:
            utils.add_to_regularization_and_summary(var)
    train_op = train(loss, trainable_var)

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()

    #Check whether the logs directory exists.
    if not os.path.exists(FLAGS.logs_dir):
        print("The logs path '%s' is not found" % FLAGS.logs_dir)
        print("Create now..")
        os.makedirs(FLAGS.logs_dir)
        print("%s is created successfully!" % FLAGS.logs_dir)

    #Create a file to write logs.
    #filename='logs'+ FLAGS.mode + str(datetime.datetime.now()) + '.txt'
    filename = "logs_%s%s.txt" % (FLAGS.mode, datetime.datetime.now())
    path_ = os.path.join(FLAGS.logs_dir, filename)
    with open(path_, 'w') as logs_file:
        logs_file.write("The logs file is created at %s.\n" %
                        datetime.datetime.now())
        logs_file.write("The model is ---%s---.\n" % FLAGS.logs_dir)
        logs_file.write("The mode is %s\n" % (FLAGS.mode))
        logs_file.write(
            "The train data batch size is %d and the validation batch size is %d\n."
            % (FLAGS.batch_size, FLAGS.v_batch_size))
        logs_file.write("The train data is %s.\n" % (FLAGS.data_dir))
        logs_file.write("The data size is %d and the MAX_ITERATION is %d.\n" %
                        (IMAGE_SIZE, MAX_ITERATION))
        logs_file.write("Setting up image reader...")

    print("Setting up image reader...")
    train_records, valid_records = scene_parsing.my_read_dataset(
        FLAGS.data_dir)
    print('number of train_records', len(train_records))
    print('number of valid_records', len(valid_records))
    with open(path_, 'a') as logs_file:
        logs_file.write('number of train_records %d\n' % len(train_records))
        logs_file.write('number of valid_records %d\n' % len(valid_records))

    print("Setting up dataset reader")
    image_options = {'resize': True, 'resize_size': IMAGE_SIZE}
    if FLAGS.mode == 'train':
        train_dataset_reader = dataset.BatchDatset(train_records,
                                                   image_options)
    validation_dataset_reader = dataset.BatchDatset(valid_records,
                                                    image_options)

    sess = tf.Session()

    print("Setting up Saver...")
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
    #If not training, restore the previously trained model.
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Model restored...")

    num_ = 0
    for itr in xrange(MAX_ITERATION):

        train_images, train_annotations = train_dataset_reader.next_batch(
            FLAGS.batch_size)
        feed_dict = {
            image: train_images,
            annotation: train_annotations,
            keep_probability: 0.85
        }
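        # keep_probability < 1.0 presumably enables dropout during training; the evaluation runs
        # below feed 1.0 to disable it.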

        for ii in range(224):
            for jj in range(224):
                if train_annotations[0][ii][jj] > 0:
                    print('anno', ii, ', ', jj, ': ',
                          train_annotations[0][ii][jj])

        sess.run(train_op, feed_dict=feed_dict)
        logs_file = open(path_, 'a')
        if itr % 10 == 0:
            train_loss, summary_str = sess.run([loss, summary_op],
                                               feed_dict=feed_dict)
            print("Step: %d, Train_loss:%g" % (itr, train_loss))
            summary_writer.add_summary(summary_str, itr)

        if itr % 500 == 0:

            #Calculate the accuracy on the training set.
            train_random_images, train_random_annotations = train_dataset_reader.get_random_batch_for_train(
                FLAGS.v_batch_size)
            train_loss, train_pred_anno = sess.run(
                [loss, pred_annotation],
                feed_dict={
                    image: train_random_images,
                    annotation: train_random_annotations,
                    keep_probability: 1.0
                })
            accu_iou_, accu_pixel_ = accu.caculate_accurary(
                train_pred_anno, train_random_annotations)
            print("%s ---> Training_loss: %g" %
                  (datetime.datetime.now(), train_loss))
            print("%s ---> Training_pixel_accuary: %g" %
                  (datetime.datetime.now(), accu_pixel_))
            print("%s ---> Training_iou_accuary: %g" %
                  (datetime.datetime.now(), accu_iou_))
            print("---------------------------")
            #Output the logs.
            num_ = num_ + 1
            logs_file.write("No.%d the itr number is %d.\n" % (num_, itr))
            logs_file.write("%s ---> Training_loss: %g.\n" %
                            (datetime.datetime.now(), train_loss))
            logs_file.write("%s ---> Training_pixel_accuary: %g.\n" %
                            (datetime.datetime.now(), accu_pixel_))
            logs_file.write("%s ---> Training_iou_accuary: %g.\n" %
                            (datetime.datetime.now(), accu_iou_))
            logs_file.write("---------------------------\n")

            valid_images, valid_annotations = validation_dataset_reader.next_batch(
                FLAGS.v_batch_size)
            #valid_loss = sess.run(loss, feed_dict={image: valid_images, annotation: valid_annotations,
            #keep_probability: 1.0})
            valid_loss, pred_anno = sess.run(
                [loss, pred_annotation],
                feed_dict={
                    image: valid_images,
                    annotation: valid_annotations,
                    keep_probability: 1.0
                })
            accu_iou, accu_pixel = accu.caculate_accurary(
                pred_anno, valid_annotations)
            print("%s ---> Validation_loss: %g" %
                  (datetime.datetime.now(), valid_loss))
            print("%s ---> Validation_pixel_accuary: %g" %
                  (datetime.datetime.now(), accu_pixel))
            print("%s ---> Validation_iou_accuary: %g" %
                  (datetime.datetime.now(), accu_iou))

            #Output the logs.
            logs_file.write("%s ---> Validation_loss: %g.\n" %
                            (datetime.datetime.now(), valid_loss))
            logs_file.write("%s ---> Validation_pixel_accuary: %g.\n" %
                            (datetime.datetime.now(), accu_pixel))
            logs_file.write("%s ---> Validation_iou_accuary: %g.\n" %
                            (datetime.datetime.now(), accu_iou))
            saver.save(sess, FLAGS.logs_dir + "model.ckpt", itr)
        '''         
        if itr % 10000 == 0:
                
            count=0
            if_con=True
            accu_iou_t=0
            accu_pixel_t=0
        
            while if_con:
                count=count+1
                valid_images, valid_annotations, filenames, if_con, start, end=validation_dataset_reader.next_batch_valid(FLAGS.v_batch_size)
                valid_loss,pred_anno=sess.run([loss,pred_annotation],feed_dict={image:valid_images,
                                                                                      annotation:valid_annotations,
                                                                                      keep_probability:1.0})
                accu_iou,accu_pixel=accu.caculate_accurary(pred_anno,valid_annotations)
            
                accu_iou_t=accu_iou_t+accu_iou
                accu_pixel_t=accu_pixel_t+accu_pixel
            print("No.%d toal validation data accurary" % itr)
            print("%s ---> Total validation_pixel_accurary: %g" % (datetime.datetime.now(),accu_pixel_t/count))
            print("%s ---> Total validation_iou_accurary: %g" % (datetime.datetime.now(),accu_iou_t/count))
            #Write logs file
            logs_file.write("No.%d toal validation data accurary.\n" % itr)
            logs_file.write("%s ---> Total validation_pixel_accurary: %g.\n" % (datetime.datetime.now(),accu_pixel_t/count))
            logs_file.write("%s ---> Total validation_iou_accurary: %g.\n" % (datetime.datetime.now(),accu_iou_t/count))'''
        #End the iterator
        logs_file.close()
Example #10
def main(argv=None):
    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    learning_rate = tf.placeholder(tf.float32, name="learning_rate")
    image = tf.placeholder(tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 7], name="input_image")
    annotation = tf.placeholder(tf.int32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 1], name="annotation")

    pred_annotation_value, pred_annotation, logits = inference(image, keep_probability)
    #logits:the last layer of conv net
    #labels:the ground truth
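    # Focal-loss-style weighting (a reading of the code, not stated in the source): a_w is the
    # class-balancing term (at for background pixels, 1 - at for foreground), and loss_weight below
    # is (1 - p_t) ** gamma, which down-weights easy, well-classified pixels.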
    
    at = float(cfgs.at)
    a_w = (1- 2*at) * tf.cast(tf.squeeze(annotation, squeeze_dims=[3]), tf.float32) + at
    

    pro = tf.nn.softmax(logits)
     
    loss_weight = tf.pow(1-tf.reduce_sum(pro * tf.one_hot(tf.squeeze(annotation, squeeze_dims=[3]), cfgs.NUM_OF_CLASSESS), 3), cfgs.gamma)
     
    
    loss = tf.reduce_mean(loss_weight * a_w * tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                                                       labels=tf.squeeze(annotation, squeeze_dims=[3]),
                                                                                       name="entropy"))


    valid_loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                                          labels=tf.squeeze(annotation, squeeze_dims=[3]),
                                                                          name="valid_entropy")))
   
    tf.summary.scalar("train_loss", loss)
    tf.summary.scalar("valid_loss", valid_loss)
    tf.summary.scalar("learning_rate", learning_rate)

    trainable_var = tf.trainable_variables()
    if cfgs.debug:
        for var in trainable_var:
            utils.add_to_regularization_and_summary(var)
    train_op = train(loss, trainable_var, learning_rate)

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()
    
    #Check whether the logs directory exists.
    if not os.path.exists(cfgs.logs_dir):
        print("The logs path '%s' is not found" % cfgs.logs_dir)
        print("Create now..")
        os.makedirs(cfgs.logs_dir)
        print("%s is created successfully!" % cfgs.logs_dir)

    #Create a file to write logs.
    #filename='logs'+ cfgs.mode + str(datetime.datetime.now()) + '.txt'
    filename="logs_%s%s.txt"%(cfgs.mode,datetime.datetime.now())
    path_=os.path.join(cfgs.logs_dir,filename)
    with open(path_,'w') as logs_file:
        logs_file.write("The logs file is created at %s.\n" % datetime.datetime.now())
        logs_file.write("The model is ---%s---.\n" % cfgs.logs_dir)
        logs_file.write("The mode is %s\n"% (cfgs.mode))
        logs_file.write("The train data batch size is %d and the validation batch size is %d\n."%(cfgs.batch_size,cfgs.v_batch_size))
        logs_file.write("The data size is %d and the MAX_ITERATION is %d.\n" % (IMAGE_SIZE, MAX_ITERATION))
        logs_file.write("Setting up image reader...")

    
    print("Setting up image reader...")
    train_records, valid_records = scene_parsing.my_read_dataset(cfgs.seq_list_path, cfgs.anno_path)
    print('number of train_records',len(train_records))
    print('number of valid_records',len(valid_records))
    with open(path_, 'a') as logs_file:
        logs_file.write('number of train_records %d\n' % len(train_records))
        logs_file.write('number of valid_records %d\n' % len(valid_records))

    path_lr = cfgs.learning_rate_path
    with open(path_lr, 'r') as f:
        lr_ = float(f.readline().split('\n')[0])


    print("Setting up dataset reader")
    vis = True if cfgs.mode == 'all_visualize' else False
    image_options = {'resize': True, 'resize_size': IMAGE_SIZE, 'visualize': vis, 'annotation':True}

    if cfgs.mode == 'train':
        train_dataset_reader = dataset.BatchDatset(train_records, image_options)
   
    validation_dataset_reader = dataset.BatchDatset(valid_records, image_options)
    #Save the current training iteration so training can resume after a stop (initialized to 0).
    current_itr_var = tf.Variable(0, dtype=tf.int32, name='current_itr')
    sess = tf.Session()

    print("Setting up Saver...")
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(cfgs.logs_dir, sess.graph)

    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state(cfgs.logs_dir)
    #If not training, restore the previously trained model.
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Model restored...")

    num_=0
    current_itr = current_itr_var.eval(sess)
    print('itr\tmode\tlr_\tacc\tiou_acc\tloss')
    time_begin = time.time()
    valid_time = 100
    for itr in xrange(current_itr, MAX_ITERATION):
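        # The learning rate is re-read from a text file every iteration (and time_to_valid.txt at
        # each validation step) so both can be adjusted while training runs, without a restart.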
        
        with open(path_lr, 'r') as f:
            lr_ = float(f.readline().split('\n')[0])
        
        train_images, train_annotations, if_finish, epoch_no= train_dataset_reader.next_batch(cfgs.batch_size)
        if if_finish:
            print('Epoch %d finished, time consumed: %.5f' % (epoch_no, time.time()-time_begin))
            time_begin = time.time()
        feed_dict = {image: train_images, annotation: train_annotations, keep_probability: 0.85, learning_rate: lr_}
        

        sess.run(train_op, feed_dict=feed_dict)
        logs_file = open(path_, 'a')  
        if itr % 10 == 0:
            train_loss, summary_str = sess.run([loss, summary_op], feed_dict=feed_dict)
            print('%d\t%s\t%g\t%s\t%s\t%.6f' % (itr, 'train',lr_, 'None','None',train_loss))
            summary_writer.add_summary(summary_str, itr)

        if itr % valid_time == 0:
            with open('time_to_valid.txt', 'r') as f:
                valid_time = int(f.readline().split('\n')[0])


            #Calculate the accuracy on the training set.
            train_random_images, train_random_annotations = train_dataset_reader.get_random_batch_for_train(cfgs.v_batch_size)
            train_loss,train_pred_anno = sess.run([loss,pred_annotation], feed_dict={image:train_random_images,
                                                                                        annotation:train_random_annotations,
                                                                                        keep_probability:1.0})
            accu_iou_,accu_pixel_ = accu.caculate_accurary(train_pred_anno, train_random_annotations)
            print('%d\t%s\t%g\t%.5f\t%.5f\t%.6f' % (itr, 'train', lr_, accu_pixel_, accu_iou_, train_loss))
            #Output the logs.
            num_ = num_ + 1
            logs_file.write('%d\t%s\t%g\t%.5f\t%.5f\t%.6f\n' % (itr, 'train', lr_, accu_pixel_, accu_iou_, train_loss))

            valid_images, valid_annotations, _, _= validation_dataset_reader.next_batch(cfgs.v_batch_size)
            #valid_loss = sess.run(loss, feed_dict={image: valid_images, annotation: valid_annotations,
                                                       #keep_probability: 1.0})
            valid_loss_,pred_anno=sess.run([loss,pred_annotation],feed_dict={image:valid_images,
                                                                                      annotation:valid_annotations,
                                                                                      keep_probability:1.0})
            accu_iou,accu_pixel=accu.caculate_accurary(pred_anno,valid_annotations)
            print('%d\t%s\t%g\t%.5f\t%.5f\t%.6f' % (itr, 'valid', lr_, accu_pixel, accu_iou, valid_loss_))
            #Output the logs.
            logs_file.write('%d\t%s\t%g\t%.5f\t%.5f\t%.6f\n' % (itr, 'valid', lr_, accu_pixel, accu_iou, valid_loss_))

            #record current itr
            sess.run(tf.assign(current_itr_var, itr))
            saver.save(sess, cfgs.logs_dir + "model.ckpt", itr)
            #End the iterator
        '''
        if itr % 10000 == 0 and itr > 0:
            lr_ = lr_/10
            print('learning rate change from %g to %g' %(lr_*10, lr_))
            logs_file.write('learning rate change from %g to %g\n' %(lr_*10, lr_))'''
        logs_file.close()
Example #11
def main(argv=None):
    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    learning_rate = tf.placeholder(tf.float32, name="learning_rate")
    image = tf.placeholder(tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3], name="input_image")
    annotation = tf.placeholder(tf.int32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 1], name="annotation")

    pred_annotation_value, pred_annotation, logits = inference(image, keep_probability)
    tf.summary.image("input_image", image, max_outputs=2)
    tf.summary.image("ground_truth", tf.cast(annotation, tf.uint8), max_outputs=2)
    tf.summary.image("pred_annotation", tf.cast(pred_annotation, tf.uint8), max_outputs=2)

    at = float(FLAGS.at)
    a_w = (1- 2*at) * tf.cast(tf.squeeze(annotation, squeeze_dims=[3]), tf.float32) + at
    #a_w = tf.reduce_sum(a_w_ * tf.one_hot(tf.squeeze(annotation, squeeze_dims=[1]), FLAGS.NUM_OF_CLASSESS), 1)
    

    pro = tf.nn.softmax(logits)
     
    loss_weight = tf.pow(1-tf.reduce_sum(pro * tf.one_hot(tf.squeeze(annotation, squeeze_dims=[3]), FLAGS.NUM_OF_CLASSESS), 3), FLAGS.gamma)
     
    
    loss = tf.reduce_mean(loss_weight * a_w * tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                                                       labels=tf.squeeze(annotation, squeeze_dims=[3]),
                                                                                       name="entropy"))

    valid_loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                                          labels=tf.squeeze(annotation, squeeze_dims=[3]),
                                                                          name="valid_entropy")))
   
    tf.summary.scalar("train_loss", loss)
    tf.summary.scalar("valid_loss", valid_loss)
    tf.summary.scalar("learning_rate", learning_rate)

    trainable_var = tf.trainable_variables()
    if FLAGS.debug:
        for var in trainable_var:
            utils.add_to_regularization_and_summary(var)
    train_op = train(loss, trainable_var, learning_rate)

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()
    
    #Check whether the logs directory exists.
    if not os.path.exists(FLAGS.logs_dir):
        print("The logs path '%s' is not found" % FLAGS.logs_dir)
        print("Create now..")
        os.makedirs(FLAGS.logs_dir)
        print("%s is created successfully!" % FLAGS.logs_dir)

    #Create a file to write logs.
    filename="logs_%s%s.txt"%(FLAGS.mode,datetime.datetime.now())
    path_=os.path.join(FLAGS.logs_dir,filename)
    with open(path_,'w') as logs_file:
        logs_file.write("The logs file is created at %s.\n" % datetime.datetime.now())
        logs_file.write("The model is ---%s---.\n" % FLAGS.logs_dir)
        logs_file.write("The mode is %s\n"% (FLAGS.mode))
        logs_file.write("The train data batch size is %d and the validation batch size is %d\n."%(FLAGS.batch_size,FLAGS.v_batch_size))
        logs_file.write("The train data is %s.\n" % (FLAGS.data_dir))
        logs_file.write("The data size is %d and the MAX_ITERATION is %d.\n" % (IMAGE_SIZE, MAX_ITERATION))
        logs_file.write("Setting up image reader...")

    
    print("Setting up image reader...")
    train_records, valid_records = scene_parsing.my_read_dataset(FLAGS.data_dir)
    print('number of train_records',len(train_records))
    print('number of valid_records',len(valid_records))
    with open(path_, 'a') as logs_file:
        logs_file.write('number of train_records %d\n' % len(train_records))
        logs_file.write('number of valid_records %d\n' % len(valid_records))

    path_lr = FLAGS.learning_rate_path
    with open(path_lr, 'r') as f:
        lr_ = float(f.readline().split('\n')[0])


    print("Setting up dataset reader")
    image_options = {'resize': True, 'resize_size': IMAGE_SIZE}
    if FLAGS.mode == 'train':
        train_dataset_reader = dataset.BatchDatset(train_records, image_options)
    validation_dataset_reader = dataset.BatchDatset(valid_records, image_options)

    sess = tf.Session()

    print("Setting up Saver...")
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
    #If not training, restore the previously trained model.
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Model restored...")

    num_=0
    current_itr = FLAGS.train_itr
    for itr in xrange(current_itr, MAX_ITERATION):

        with open(path_lr, 'r') as f:
            lr_ = float(f.readline().split('\n')[0])

        train_images, train_annotations = train_dataset_reader.next_batch(FLAGS.batch_size)
        feed_dict = {image: train_images, annotation: train_annotations, keep_probability: 0.85, learning_rate: lr_}
        

        sess.run(train_op, feed_dict=feed_dict)
        logs_file = open(path_, 'a')  
        if itr % 10 == 0:
            train_loss, summary_str = sess.run([loss, summary_op], feed_dict=feed_dict)
            print("Step: %d, Train_loss:%g" % (itr, train_loss))
            summary_writer.add_summary(summary_str, itr)

        if itr % 500 == 0:
            
            #Calculate the accuracy on the training set.
            train_random_images, train_random_annotations = train_dataset_reader.get_random_batch_for_train(FLAGS.v_batch_size)
            train_loss,train_pred_anno = sess.run([loss,pred_annotation], feed_dict={image:train_random_images,
                                                                                        annotation:train_random_annotations,
                                                                                        keep_probability:1.0})
            accu_iou_,accu_pixel_ = accu.caculate_accurary(train_pred_anno, train_random_annotations)
            print("%s ---> Training_loss: %g" % (datetime.datetime.now(), train_loss))
            print("%s ---> Training_pixel_accuary: %g" % (datetime.datetime.now(),accu_pixel_))
            print("%s ---> Training_iou_accuary: %g" % (datetime.datetime.now(),accu_iou_))
            print("---------------------------")
            #Output the logs.
            num_ = num_ + 1
            logs_file.write("No.%d the itr number is %d.\n" % (num_, itr))
            logs_file.write("%s ---> Training_loss: %g.\n" % (datetime.datetime.now(), train_loss))
            logs_file.write("%s ---> Training_pixel_accuary: %g.\n" % (datetime.datetime.now(),accu_pixel_))
            logs_file.write("%s ---> Training_iou_accuary: %g.\n" % (datetime.datetime.now(),accu_iou_))
            logs_file.write("---------------------------\n")

            valid_images, valid_annotations = validation_dataset_reader.next_batch(FLAGS.v_batch_size)
            #valid_loss = sess.run(loss, feed_dict={image: valid_images, annotation: valid_annotations,
                                                       #keep_probability: 1.0})
            valid_loss_,pred_anno=sess.run([valid_loss,pred_annotation],feed_dict={image:valid_images,
                                                                                      annotation:valid_annotations,
                                                                                      keep_probability:1.0})
            accu_iou,accu_pixel=accu.caculate_accurary(pred_anno,valid_annotations)
            print("%s ---> Validation_loss: %g" % (datetime.datetime.now(), valid_loss_))
            print("%s ---> Validation_pixel_accuary: %g" % (datetime.datetime.now(),accu_pixel))
            print("%s ---> Validation_iou_accuary: %g" % (datetime.datetime.now(),accu_iou))

            #Output the logs.
            logs_file.write("%s ---> Validation_loss: %g.\n" % (datetime.datetime.now(), valid_loss_))
            logs_file.write("%s ---> Validation_pixel_accuary: %g.\n" % (datetime.datetime.now(),accu_pixel))
            logs_file.write("%s ---> Validation_iou_accuary: %g.\n" % (datetime.datetime.now(),accu_iou))
            saver.save(sess, FLAGS.logs_dir + "model.ckpt", itr)
            #End the iterator
        logs_file.close()
Example #12
def main(argv=None):
    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    image = tf.placeholder(tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3], name="input_image")
    soft_annotation = tf.placeholder(tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 2], name="soft_annotation")
    hard_annotation = tf.placeholder(tf.int32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 1], name="hard_annotation")

    pred_annotation_value, pred_annotation, logits = inference(image, keep_probability)
    #tf.summary.image("input_image", image, max_outputs=2)
    #tf.summary.image("ground_truth", tf.cast(soft_annotation, tf.uint8), max_outputs=2)
    #tf.summary.image("pred_annotation", tf.cast(pred_annotation, tf.uint8), max_outputs=2)
    #logits:the last layer of conv net
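    # Knowledge-distillation-style training (an assumption about intent): the logits are softened
    # by FLAGS.temperature before the cross-entropy against the soft annotations, while a separate
    # hard-label loss is kept for evaluation on the test set.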
    soft_logits = tf.nn.softmax(logits/FLAGS.temperature)
    soft_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits/FLAGS.temperature,
                                                                        labels = soft_annotation,
                                                                        name = "entropy_soft"))
    hard_loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                                  labels=tf.squeeze(hard_annotation, squeeze_dims=[3]),
                                                                  name="entropy_hard")))

    tf.summary.scalar("entropy", soft_loss)

    trainable_var = tf.trainable_variables()
    if FLAGS.debug:
        for var in trainable_var:
            utils.add_to_regularization_and_summary(var)
    train_op = train(soft_loss, trainable_var)

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()
    
    #Check whether the logs directory exists.
    if not os.path.exists(FLAGS.logs_dir):
        print("The logs path '%s' is not found" % FLAGS.logs_dir)
        print("Create now..")
        os.makedirs(FLAGS.logs_dir)
        print("%s is created successfully!" % FLAGS.logs_dir)

    #Create a file to write logs.
    filename="logs_%s%s.txt"%(FLAGS.mode,datetime.datetime.now())
    path_=os.path.join(FLAGS.logs_dir,filename)
    with open(path_,'w') as logs_file:
        logs_file.write("The logs file is created at %s.\n" % datetime.datetime.now())
        logs_file.write("The model is ---%s---.\n" % FLAGS.logs_dir)
        logs_file.write("The mode is %s\n"% (FLAGS.mode))
        logs_file.write("The train data batch size is %d and the validation batch size is %d\n."%(FLAGS.batch_size,FLAGS.v_batch_size))
        logs_file.write("The train data is %s.\n" % (FLAGS.data_dir))
        logs_file.write("The data size is %d and the MAX_ITERATION is %d.\n" % (IMAGE_SIZE, MAX_ITERATION))
        logs_file.write("Setting up image reader...")

    
    print("Setting up image reader...")
    train_records, valid_records, test_records = scene_parsing.read_dataset_tvt(FLAGS.data_dir)
    print('number of train_records',len(train_records))
    print('number of valid_records',len(valid_records))
    print('number of test records', len(test_records))
    with open(path_, 'a') as logs_file:
        logs_file.write('number of train_records %d\n' % len(train_records))
        logs_file.write('number of valid_records %d\n' % len(valid_records))
        logs_file.write('number of test_records %d\n' % len(test_records))

    print("Setting up dataset reader")
    image_options = {'resize': True, 'resize_size': IMAGE_SIZE}
    if FLAGS.mode == 'train':
        train_dataset_reader = dataset_soft.BatchDatset(train_records, image_options)
        validation_dataset_reader = dataset_soft.BatchDatset(valid_records, image_options)
    test_dataset_reader = dataset.BatchDatset(test_records, image_options)
    sess = tf.Session()

    print("Setting up Saver...")
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
    #If not training, restore the previously trained model.
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Model restored...")

    num_=0
    for itr in xrange(MAX_ITERATION):
        
        train_images, train_annotations = train_dataset_reader.next_batch(FLAGS.batch_size)
        train_annotations_n = train_annotations/FLAGS.normal
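        # FLAGS.normal presumably rescales the stored soft labels into [0, 1] probabilities before
        # they are fed as soft targets.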
        feed_dict = {image: train_images, soft_annotation: train_annotations_n, keep_probability: 0.85}

        sess.run(train_op, feed_dict=feed_dict)
        logs_file = open(path_, 'a')  
        if itr % 10 == 0:
            soft_train_logits, train_logits = sess.run([soft_logits, logits], feed_dict=feed_dict)
            train_logits = np.array(train_logits)
            soft_train_logits = np.array(soft_train_logits)

            train_loss, summary_str = sess.run([soft_loss, summary_op], feed_dict=feed_dict)
            print("Step: %d, Train_loss:%g" % (itr, train_loss))
            logs_file.write("Step: %d, Train_loss:%g\n" % (itr, train_loss))
            summary_writer.add_summary(summary_str, itr)

        if itr % 500 == 0:
            
            #Calculate the accuracy on the validation set.
            valid_images, valid_annotations = validation_dataset_reader.get_records()
            valid_logits,valid_loss,valid_pred_anno = sess.run([soft_logits,soft_loss,pred_annotation], feed_dict={image:valid_images,
                                                                                        soft_annotation:valid_annotations/FLAGS.normal,
                                                                                        keep_probability:1.0})
            print('shape of valid_annotations', valid_annotations.shape)
            #Calculate accuracy
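            # Sweep the soft threshold from soft_thresh_lower to soft_thresh_upper and keep the value
            # that maximizes w_iou * IoU + (1 - w_iou) * pixel accuracy on this validation batch.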
            choose_len = int(math.ceil((FLAGS.soft_thresh_upper - FLAGS.soft_thresh_lower)/FLAGS.interval))
            test_thresh = FLAGS.soft_thresh_lower
            accu_sum_max = -1
            accu_iou_max = 0
            accu_pixel_max = 0
            choose_thresh = test_thresh
            print(choose_len)
            for i in range(choose_len):
                accu_iou_,accu_pixel_ = accu.caculate_soft_accurary(valid_logits, valid_annotations/FLAGS.normal, test_thresh)
                accu_sum = accu_iou_*FLAGS.w_iou + accu_pixel_*(1-FLAGS.w_iou)
                print('accu_sum: %g' % accu_sum)
                if accu_sum>accu_sum_max:
                    choose_thresh = test_thresh
                    accu_sum_max = accu_sum
                    accu_iou_max = accu_iou_
                    accu_pixel_max = accu_pixel_
                print('%d accu_iou:%g, accu_pixel:%g, accu_sum:%g, test_thresh:%g' % (i,accu_iou_max, accu_pixel_max,accu_sum_max,test_thresh))
                test_thresh += FLAGS.interval
            FLAGS.soft_thresh = choose_thresh
            print("%s ---> Validation_loss: %g" % (datetime.datetime.now(), train_loss))
            print("%s ---> The chosen soft_threshhold is %g." % (datetime.datetime.now(),FLAGS.soft_thresh))
            print("%s ---> Validation_pixel_accuary: %g" % (datetime.datetime.now(),accu_pixel_max))
            print("%s ---> Validation_iou_accuary: %g" % (datetime.datetime.now(),accu_iou_max))

            print("---------------------------")
            #Output the logs.
            num_ = num_ + 1
            logs_file.write("No.%d the itr number is %d.\n" % (num_, itr))
            logs_file.write("%s ---> Validation_loss: %g.\n" % (datetime.datetime.now(), train_loss))
            logs_file.write("%s ---> The chosen soft_threshhold is %g.\n" % (datetime.datetime.now(),FLAGS.soft_thresh))
            logs_file.write("%s ---> Validation_pixel_accuary: %g.\n" % (datetime.datetime.now(),accu_pixel_))
            logs_file.write("%s ---> Validation_iou_accuary: %g.\n" % (datetime.datetime.now(),accu_iou_))
            logs_file.write("---------------------------\n")
            
            #test dataset
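            #Evaluate a test batch against the hard integer labels, scoring predictions with the chosen soft threshold.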
            test_images, test_annotations = test_dataset_reader.next_batch(FLAGS.v_batch_size)
            test_loss,pred_anno=sess.run([hard_loss,pred_annotation],feed_dict={image:test_images,
                                                                                      hard_annotation:test_annotations,
                                                                                      keep_probability:1.0})
            accu_iou,accu_pixel=accu.caculate_soft_accurary(pred_anno, test_annotations, FLAGS.soft_thresh)
            print("%s ---> test_loss: %g" % (datetime.datetime.now(), test_loss))
            print("%s ---> test_pixel_accuracy: %g" % (datetime.datetime.now(), accu_pixel))
            print("%s ---> test_iou_accuracy: %g" % (datetime.datetime.now(), accu_iou))

            #Output the logs.
            logs_file.write("%s ---> test_loss: %g.\n" % (datetime.datetime.now(), test_loss))
            logs_file.write("%s ---> test_pixel_accuracy: %g.\n" % (datetime.datetime.now(), accu_pixel))
            logs_file.write("%s ---> test_iou_accuracy: %g.\n" % (datetime.datetime.now(), accu_iou))
            saver.save(sess, FLAGS.logs_dir + "model.ckpt", itr)
            #End the iterator
        logs_file.close()
Example #13
0
def main(argv=None):
    keep_probability = tf.placeholder(tf.float32, name="keep_probability")
    image = tf.placeholder(tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3], name="input_image")
    soft_annotation = tf.placeholder(tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 2], name="soft_annotation")
    hard_annotation = tf.placeholder(tf.int32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 1], name="hard_annotation")

    pred_annotation_value, pred_annotation, logits = inference(image, keep_probability)
    #tf.summary.image("input_image", image, max_outputs=2)
    #tf.summary.image("ground_truth", tf.cast(soft_annotation, tf.uint8), max_outputs=2)
    #tf.summary.image("pred_annotation", tf.cast(pred_annotation, tf.uint8), max_outputs=2)
    #logits:the last layer of conv net
    #labels:the ground truth
    #loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
    #                                                                      labels=tf.squeeze(annotation, squeeze_dims=[3]),
    #                                                                      name="entropy")))
    #TODO: this update is not finished; the weighted sigmoid soft-loss variant below is kept commented out for reference.
    '''
    soft_logits = tf.nn.softmax(logits/FLAGS.temperature)
    soft_loss = tf.reduce_mean((tf.nn.sigmoid_cross_entropy_with_logits(logits=soft_logits,
                                                                 #labels = tf.squeeze(soft_annotation, squeeze_dims=[3]),
                                                                  labels = soft_annotation,
                                                                  name = "entropy_soft")))
    #use weight
    soft_loss0 = 0.8*tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=soft_logits[:,:,:,0],
                                                         labels=soft_annotation[:,:,:,0],
                                                         name ="entropy_soft0"))
    soft_loss1 = 0.2*tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=soft_logits[:,:,:,1],
                                                         labels=soft_annotation[:,:,:,1],
                                                         name ="entropy_soft1"))
    soft_loss = tf.add(soft_loss0, soft_loss1)
    '''
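    #Soft loss: softmax cross-entropy on temperature-scaled logits against the two-channel soft
    #annotations (a knowledge-distillation style target). Hard loss: standard sparse softmax
    #cross-entropy against the integer hard annotations.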
    soft_logits = tf.nn.softmax(logits/FLAGS.temperature)
    soft_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits/FLAGS.temperature,
                                                                        labels = soft_annotation,
                                                                        name = "entropy_soft"))
    hard_loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                                  labels=tf.squeeze(hard_annotation, squeeze_dims=[3]),
                                                                  name="entropy_hard")))

    tf.summary.scalar("entropy", soft_loss)

    trainable_var = tf.trainable_variables()
    if FLAGS.debug:
        for var in trainable_var:
            utils.add_to_regularization_and_summary(var)
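    #Only the soft (distillation) loss is optimized; the hard loss is used for evaluation.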
    train_op = train(soft_loss, trainable_var)

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()
    
    #Create the logs directory if it does not exist.
    if not os.path.exists(FLAGS.logs_dir):
        print("The logs path '%s' was not found" % FLAGS.logs_dir)
        print("Creating it now...")
        os.makedirs(FLAGS.logs_dir)
        print("%s was created successfully!" % FLAGS.logs_dir)

    #Create a file to write logs.
    #filename='logs' + FLAGS.mode + str(datetime.datetime.now()) + '.txt'
    filename="logs_%s%s.txt"%(FLAGS.mode,datetime.datetime.now())
    path_=os.path.join(FLAGS.logs_dir,filename)
    with open(path_,'w') as logs_file:
        logs_file.write("The logs file is created at %s.\n" % datetime.datetime.now())
        logs_file.write("The model is ---%s---.\n" % FLAGS.logs_dir)
        logs_file.write("The mode is %s.\n" % (FLAGS.mode))
        logs_file.write("The training batch size is %d and the validation batch size is %d.\n" % (FLAGS.batch_size, FLAGS.v_batch_size))
        logs_file.write("The training data is %s.\n" % (FLAGS.data_dir))
        logs_file.write("The image size is %d and the MAX_ITERATION is %d.\n" % (IMAGE_SIZE, MAX_ITERATION))
        logs_file.write("Setting up image reader...\n")

    
    print("Setting up image reader...")
    train_records, valid_records = scene_parsing.my_read_dataset(FLAGS.data_dir)
    print('number of train_records',len(train_records))
    print('number of valid_records',len(valid_records))
    with open(path_, 'a') as logs_file:
        logs_file.write('number of train_records %d\n' % len(train_records))
        logs_file.write('number of valid_records %d\n' % len(valid_records))

    print("Setting up dataset reader")
    image_options = {'resize': True, 'resize_size': IMAGE_SIZE}
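    #The training reader (dataset_soft) yields the soft two-channel annotations fed to soft_annotation,
    #while the validation reader (dataset) yields hard integer masks fed to hard_annotation.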
    if FLAGS.mode == 'train':
        train_dataset_reader = dataset_soft.BatchDatset(train_records, image_options)
    validation_dataset_reader = dataset.BatchDatset(valid_records, image_options)

    sess = tf.Session()

    print("Setting up Saver...")
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
    #If a checkpoint exists, restore the previously trained model.
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Model restored...")

    num_=0
    for itr in xrange(MAX_ITERATION):
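        #Each step: fetch a training batch, scale the annotations by FLAGS.normal to form soft targets,
        #and take one optimizer step on the soft loss with dropout keep probability 0.85.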
        
        train_images, train_annotations = train_dataset_reader.next_batch(FLAGS.batch_size)
        train_annotations_n = train_annotations/FLAGS.normal
        feed_dict = {image: train_images, soft_annotation: train_annotations_n, keep_probability: 0.85}

        sess.run(train_op, feed_dict=feed_dict)
        logs_file = open(path_, 'a')  
        if itr % 10 == 0:
            '''
            num_ = 0
            for ii in range(224):
                for jj in range(224):
                    #if train_annotations_n[0][ii][jj][1] > train_annotations_n[0][ii][jj][0]:
                    if train_annotations_n[0][ii][jj][1] > 0.85:
                        #print('anno',ii,', ', jj,': ', train_annotations_n[0][ii][jj])
                        num_ = num_+1
            print("the number of dim1>0.85: %d" % num_)
            '''
            soft_train_logits, train_logits = sess.run([soft_logits, logits], feed_dict=feed_dict)
            train_logits = np.array(train_logits)
            soft_train_logits = np.array(soft_train_logits)
            '''
            num_ = 0
            for ii in range(224):
                for jj in range(224):
                    if train_annotations_n[0][ii][jj][1] > 0.85:
                    #if soft_train_logits[0][ii][jj][1] > soft_train_logits[0][ii][jj][0]:
                    #if soft_train_logits[0][ii][jj][1] > 0.82:
                        #print('logtis',ii,', ', jj,': ', train_logits[0][ii][jj])
                        print('soft_logtis',ii,', ', jj,': ', soft_train_logits[0][ii][jj])
                        print('anno       ',ii,', ', jj,': ', train_annotations_n[0][ii][jj])
                        print('--------------------')
                    if soft_train_logits[0][ii][jj][1] > 0.82:
                        num_ = num_+1
            print("the number of dim1>0.82: %d" % num_)
            '''

            train_loss, summary_str = sess.run([soft_loss, summary_op], feed_dict=feed_dict)
            print("Step: %d, Train_loss:%g" % (itr, train_loss))
            logs_file.write("Step: %d, Train_loss:%g\n" % (itr, train_loss))
            summary_writer.add_summary(summary_str, itr)

        if itr % 500 == 0:
            
            #Calculate the accuracy on the training set.
            train_random_images, train_random_annotations = train_dataset_reader.get_random_batch_for_train(FLAGS.v_batch_size)
            train_logits,train_loss,train_pred_anno = sess.run([soft_logits,soft_loss,pred_annotation], feed_dict={image:train_random_images,
                                                                                        soft_annotation:train_random_annotations/FLAGS.normal,
                                                                                        keep_probability:1.0})
            #accu_iou_,accu_pixel_ = accu.caculate_accurary(train_pred_anno, train_random_annotations/100)
            print("%s ---> Training_loss: %g" % (datetime.datetime.now(), train_loss))
            #print("%s ---> Training_pixel_accuracy: %g" % (datetime.datetime.now(), accu_pixel_))
            #print("%s ---> Training_iou_accuracy: %g" % (datetime.datetime.now(), accu_iou_))
            print("---------------------------")
            #Output the logs.
            num_ = num_ + 1
            logs_file.write("No.%d the itr number is %d.\n" % (num_, itr))
            logs_file.write("%s ---> Training_loss: %g.\n" % (datetime.datetime.now(), train_loss))
            #logs_file.write("%s ---> Training_pixel_accuracy: %g.\n" % (datetime.datetime.now(), accu_pixel_))
            #logs_file.write("%s ---> Training_iou_accuracy: %g.\n" % (datetime.datetime.now(), accu_iou_))
            logs_file.write("---------------------------\n")

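            #Evaluate on a validation batch with the hard integer labels and the hard loss,
            #then report pixel and IoU accuracy.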
            valid_images, valid_annotations = validation_dataset_reader.next_batch(FLAGS.v_batch_size)
            valid_loss,pred_anno=sess.run([hard_loss,pred_annotation],feed_dict={image:valid_images,
                                                                                      hard_annotation:valid_annotations,
                                                                                      keep_probability:1.0})
            accu_iou,accu_pixel=accu.caculate_accurary(pred_anno,valid_annotations)
            print("%s ---> Validation_loss: %g" % (datetime.datetime.now(), valid_loss))
            print("%s ---> Validation_pixel_accuracy: %g" % (datetime.datetime.now(), accu_pixel))
            print("%s ---> Validation_iou_accuracy: %g" % (datetime.datetime.now(), accu_iou))

            #Output the logs.
            logs_file.write("%s ---> Validation_loss: %g.\n" % (datetime.datetime.now(), valid_loss))
            logs_file.write("%s ---> Validation_pixel_accuracy: %g.\n" % (datetime.datetime.now(), accu_pixel))
            logs_file.write("%s ---> Validation_iou_accuracy: %g.\n" % (datetime.datetime.now(), accu_iou))
            saver.save(sess, FLAGS.logs_dir + "model.ckpt", itr)
            #End the iterator
        logs_file.close()
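
For reference, the effect of dividing the logits by FLAGS.temperature in the soft loss can be illustrated with a small, self-contained NumPy sketch; it is not part of the example above, and the logit values and temperatures below are made up purely for illustration.

import numpy as np

def softmax(x, axis=-1):
    #Numerically stable softmax.
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)

logits = np.array([4.0, 1.0])  #made-up two-class scores for a single pixel
for temperature in (1.0, 2.0, 5.0):
    soft_target = softmax(logits / temperature)
    print('T=%g -> soft target %s' % (temperature, np.round(soft_target, 3)))
#T=1 gives a nearly one-hot target; larger temperatures flatten it, which is the kind of
#softened target that the softmax cross-entropy on logits/FLAGS.temperature trains against.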