Example #1

import tensorflow as tf

# Assumes the project-local names scene_parsing, cfgs, and SeqFCNNet are in scope.
def video_main():
    with tf.device('/gpu:0'):
        train_records, valid_records = scene_parsing.my_read_video_dataset(
            cfgs.seq_list_path, cfgs.anno_path)
        print('The number of validation video records is %d.' % len(valid_records))
        model = SeqFCNNet(cfgs.mode, cfgs.max_epochs, cfgs.batch_size,
                          cfgs.NUM_OF_CLASSESS, train_records, valid_records,
                          cfgs.IMAGE_SIZE, cfgs.init_lr, cfgs.keep_prob,
                          cfgs.logs_dir)
        model.build_video()  # construct the sequence-FCN graph
        model.vis_video()    # run visualization over the validation records
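
# A minimal entry-point sketch (an assumption; the original snippet stops at
# vis_video):
if __name__ == '__main__':
    video_main()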
Example #2

import datetime
import os

import numpy as np
import tensorflow as tf

# Assumes project-local modules/names are in scope: scene_parsing, dataset, utils,
# accu, cfgs, inference, IMAGE_SIZE, pred_visualize, fit_ellipse_findContours,
# density_heatmap, and translucent_heatmap.
def main(argv=None):
    keep_probability = tf.placeholder(tf.float32, name="keep_probability")
    image = tf.placeholder(tf.float32,
                           shape=[None, IMAGE_SIZE, IMAGE_SIZE, 7],
                           name="input_image")
    annotation = tf.placeholder(tf.int32,
                                shape=[None, IMAGE_SIZE, IMAGE_SIZE, 1],
                                name="annotation")

    pred_annotation_value, pred_annotation, logits, pred_prob = inference(
        image, keep_probability)
    tf.summary.image("input_image", image, max_outputs=2)
    tf.summary.image("ground_truth",
                     tf.cast(annotation, tf.uint8),
                     max_outputs=2)
    tf.summary.image("pred_annotation",
                     tf.cast(pred_annotation, tf.uint8),
                     max_outputs=2)
    # logits: output of the final conv layer; labels: the ground-truth annotation.
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=tf.squeeze(annotation, axis=[3]),
        name="entropy"))
    tf.summary.scalar("entropy", loss)
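
    # Shape note: the annotation placeholder is [batch, H, W, 1]; squeezing axis 3
    # yields the [batch, H, W] integer label map that sparse softmax cross entropy
    # expects, paired with logits of shape [batch, H, W, num_classes], and
    # tf.reduce_mean averages the per-pixel losses over the whole batch.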

    trainable_var = tf.trainable_variables()
    if cfgs.debug:
        for var in trainable_var:
            utils.add_to_regularization_and_summary(var)

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()

    # Create a file to write logs.
    filename = "logs_%s%s.txt" % (cfgs.mode, datetime.datetime.now())
    path_ = os.path.join(cfgs.logs_dir, filename)
    logs_file = open(path_, 'w')
    logs_file.write("The logs file is created at %s\n" %
                    datetime.datetime.now())
    logs_file.write("The mode is %s\n" % (cfgs.mode))
    logs_file.write(
        "The train data batch size is %d and the validation batch size is %d.\n"
        % (cfgs.batch_size, cfgs.v_batch_size))
    logs_file.write("The train data is %s.\n" % (cfgs.data_dir))
    logs_file.write("The model is ---%s---.\n" % cfgs.logs_dir)

    print("Setting up image reader...")
    logs_file.write("Setting up image reader...\n")
    train_records, valid_records = scene_parsing.my_read_video_dataset(
        cfgs.seq_list_path, cfgs.anno_path)
    print('number of train_records', len(train_records))
    print('number of valid_records', len(valid_records))
    logs_file.write('number of train_records %d\n' % len(train_records))
    logs_file.write('number of valid_records %d\n' % len(valid_records))

    print("Setting up dataset reader")
    vis = True if cfgs.mode == 'all_visualize' else False
    image_options = {
        'resize': True,
        'resize_size': IMAGE_SIZE,
        'visualize': vis
    }

    if cfgs.mode == 'train':
        train_dataset_reader = dataset.BatchDatset(train_records,
                                                   image_options)
    validation_dataset_reader = dataset.BatchDatset(valid_records,
                                                    image_options)
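
    # Only 'train' mode builds a training reader; the evaluation branches below read
    # exclusively from the validation reader.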

    sess = tf.Session()

    print("Setting up Saver...")
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(cfgs.logs_dir, sess.graph)

    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state(cfgs.logs_dir)
    # Restore the most recently saved checkpoint, if any (needed for non-train modes).
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Model restored...")

    if cfgs.mode == "accurary":
        count = 0
        if_con = True
        accu_iou_t = 0
        accu_pixel_t = 0

        while if_con:
            count += 1
            valid_images, valid_annotations, valid_filenames, if_con, start, end = validation_dataset_reader.next_batch_valid(
                cfgs.v_batch_size)
            valid_loss, pred_anno = sess.run(
                [loss, pred_annotation],
                feed_dict={
                    image: valid_images,
                    annotation: valid_annotations,
                    keep_probability: 1.0
                })
            accu_iou, accu_pixel = accu.caculate_accurary(
                pred_anno, valid_annotations)
            print("Ture %d ---> the data from %d to %d" % (count, start, end))
            print("%s ---> Validation_pixel_accuary: %g" %
                  (datetime.datetime.now(), accu_pixel))
            print("%s ---> Validation_iou_accuary: %g" %
                  (datetime.datetime.now(), accu_iou))
            #Output logs.
            logs_file.write("Ture %d ---> the data from %d to %d\n" %
                            (count, start, end))
            logs_file.write("%s ---> Validation_pixel_accuary: %g\n" %
                            (datetime.datetime.now(), accu_pixel))
            logs_file.write("%s ---> Validation_iou_accuary: %g\n" %
                            (datetime.datetime.now(), accu_iou))

            accu_iou_t += accu_iou
            accu_pixel_t += accu_pixel
        print("%s ---> Total validation_pixel_accuary: %g" %
              (datetime.datetime.now(), accu_pixel_t / count))
        print("%s ---> Total validation_iou_accuary: %g" %
              (datetime.datetime.now(), accu_iou_t / count))
        #Output logs
        logs_file.write("%s ---> Total validation_pixel_accurary: %g\n" %
                        (datetime.datetime.now(), accu_pixel_t / count))
        logs_file.write("%s ---> Total validation_iou_accurary: %g\n" %
                        (datetime.datetime.now(), accu_iou_t / count))
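
        # Note: the totals are averaged over the number of validation batches, so a
        # smaller final batch carries the same weight as a full one.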

    elif cfgs.mode == "all_visualize":

        re_save_dir = "%s%s" % (cfgs.result_dir, datetime.datetime.now())
        logs_file.write("The result is save at file'%s'.\n" % re_save_dir)
        logs_file.write("The number of part visualization is %d.\n" %
                        cfgs.v_batch_size)

        # Create the result directory if it does not already exist.
        if not os.path.exists(re_save_dir):
            print("The path '%s' is not found." % re_save_dir)
            print("Create now ...")
            os.makedirs(re_save_dir)
            print("Create '%s' successfully." % re_save_dir)
            logs_file.write("Create '%s' successfully.\n" % re_save_dir)

        re_save_dir_im = os.path.join(re_save_dir, 'images')
        re_save_dir_heat = os.path.join(re_save_dir, 'heatmap')
        re_save_dir_ellip = os.path.join(re_save_dir, 'ellip')
        re_save_dir_transheat = os.path.join(re_save_dir, 'transheat')
        for d in (re_save_dir_im, re_save_dir_heat, re_save_dir_ellip,
                  re_save_dir_transheat):
            if not os.path.exists(d):
                os.makedirs(d)

        count = 0
        if_con = True

        while if_con:
            count += 1
            valid_images, valid_filename, valid_cur_images, if_con, start, end = validation_dataset_reader.next_batch_video_valid(
                cfgs.v_batch_size)
            pred_value, pred, logits_, pred_prob_ = sess.run(
                [pred_annotation_value, pred_annotation, logits, pred_prob],
                feed_dict={
                    image: valid_images,
                    keep_probability: 1.0
                })
            print("Turn %d :----start from %d ------- to %d" %
                  (count, start, end))
            pred = np.squeeze(pred, axis=3)
            pred_value = np.squeeze(pred_value, axis=3)

            for itr in range(len(pred)):
                filename = valid_filename[itr]['filename']
                valid_images_ = pred_visualize(valid_cur_images[itr].copy(),
                                               pred[itr])
                utils.save_image(valid_images_.astype(np.uint8),
                                 re_save_dir_im,
                                 name="inp_" + filename)

                if cfgs.fit_ellip:
                    # Fit an ellipse to the predicted mask via contour detection.
                    valid_images_ellip = fit_ellipse_findContours(
                        valid_cur_images[itr].copy(),
                        np.expand_dims(pred[itr], axis=2).astype(np.uint8))
                    utils.save_image(valid_images_ellip.astype(np.uint8),
                                     re_save_dir_ellip,
                                     name="ellip_" + filename)
                if cfgs.heatmap:
                    # Density heatmap of the foreground-class probability (channel 1).
                    heat_map = density_heatmap(pred_prob_[itr, :, :, 1])
                    utils.save_image(heat_map.astype(np.uint8),
                                     re_save_dir_heat,
                                     name="heat_" + filename)
                # translucent_heatmap overlays heat_map, which only exists when
                # cfgs.heatmap is also enabled.
                if cfgs.trans_heat and cfgs.heatmap:
                    trans_heat_map = translucent_heatmap(
                        valid_cur_images[itr],
                        heat_map.astype(np.uint8).copy())
                    utils.save_image(trans_heat_map,
                                     re_save_dir_transheat,
                                     name="trans_heat_" + filename)

    logs_file.close()