# Example #1
# 0
def main(_):
    """Evaluate a trained KittiSeg model with ``kitti_eval.evaluate_test``.

    Weights come from ``--logdir``. When that flag is unset, the MultiNet
    paper weights are downloaded (if needed) into the runs directory
    (``$TV_DIR_RUNS/KittiSeg`` or ``RUNS``) and ``default_run`` inside it is
    used. Every metric returned by the evaluation is logged.

    Args:
        _: Unused (presumably the argv list passed by ``tf.app.run``;
           confirm against the entry point).
    """
    tv_utils.set_gpus_to_use()

    # NOTE(review): --input_image is required here but never used by the
    # evaluation below; the guard is kept so the command line is unchanged.
    if FLAGS.input_image is None:
        logging.error("No input_image was given.")
        logging.info(
            "Usage: python demo.py --input_image data/test.png "
            "[--output_image output_image] [--logdir /path/to/weights] "
            "[--gpus GPUs_to_use] ")
        exit(1)

    if FLAGS.logdir is None:
        # Download and use weights from the MultiNet Paper
        if 'TV_DIR_RUNS' in os.environ:
            runs_dir = os.path.join(os.environ['TV_DIR_RUNS'], 'KittiSeg')
        else:
            runs_dir = 'RUNS'
        maybe_download_and_extract(runs_dir)
        logdir = os.path.join(runs_dir, default_run)
    else:
        logging.info("Using weights found in {}".format(FLAGS.logdir))
        logdir = FLAGS.logdir

    # Loading hyperparameters from logdir
    hypes = tv_utils.load_hypes_from_logdir(logdir, base_path='hypes')
    logging.info("Hypes loaded successfully.")

    # Loading tv modules (encoder.py, decoder.py, eval.py) from logdir
    modules = tv_utils.load_modules_from_logdir(logdir)
    logging.info("Modules loaded successfully. Starting to build tf graph.")

    # Create tf graph and build module.
    with tf.Graph().as_default():
        # Input: a batch of exactly one 3-channel image of any height/width.
        image_placehold = tf.placeholder(tf.float32, shape=[1, None, None, 3])

        # build Tensorflow graph using the model from logdir
        prediction = core.build_inference_graph(hypes,
                                                modules,
                                                image=image_placehold)
        logging.info("Graph build successfully.")

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        # The session is a context manager so it is closed (and its device
        # memory released) when evaluation finishes or if any step raises.
        with tf.Session(config=config) as sess:
            saver = tf.train.Saver()

            # Load weights from logdir
            core.load_weights(logdir, sess, saver)

            # eval_dict is iterated as a sequence of (name, value) pairs below.
            eval_dict, _ = kitti_eval.evaluate_test(hypes, sess,
                                                    image_placehold,
                                                    prediction)

        for name, value in eval_dict:
            logging.info('    %s %s : % 0.04f ', name, '(raw)', value)
# Example #2
# 0
def train_loop(myhypes=None):
    """Train a model described by a hypes (JSON with comments) file.

    Args:
        myhypes: Path of the hypes file to load. When ``None`` (the default)
            the file given by ``--hypes`` is loaded instead.

    Exits with status 1 when the required git submodules cannot be imported
    or when ``--hypes`` is not set.
    """
    utils.set_gpus_to_use()

    try:
        # Imported only to verify that the git submodules are checked out.
        import tensorvision.train
        import tensorflow_fcn.utils
    except ImportError:
        logging.error("Could not import the submodules.")
        logging.error("Please execute:"
                      "'git submodule update --init --recursive'")
        exit(1)

    if tf.app.flags.FLAGS.hypes is None:
        logging.error("No hype file is given.")
        logging.info("Usage: python train.py --hypes hypes/KittiClass.json")
        exit(1)

    # Default to the --hypes file so train_loop() can be called with no
    # arguments instead of failing on open(None).
    hypes_path = myhypes if myhypes is not None else tf.app.flags.FLAGS.hypes
    with open(hypes_path, 'r') as f:
        logging.info("f: %s", f)
        hypes = commentjson.load(f)
    utils.load_plugins()

    if tf.app.flags.FLAGS.mod is not None:
        import ast
        # --mod holds a Python literal (parsed safely, no eval) whose entries
        # are merged into the loaded hypes.
        mod_dict = ast.literal_eval(tf.app.flags.FLAGS.mod)
        dict_merge(hypes, mod_dict)

    # NOTE(review): this appends 'KittiSeg' to TV_DIR_RUNS on every call, so
    # calling train_loop() repeatedly in one process nests the path further
    # each time; confirm whether that is intended.
    if 'TV_DIR_RUNS' in os.environ:
        os.environ['TV_DIR_RUNS'] = os.path.join(os.environ['TV_DIR_RUNS'],
                                                 'KittiSeg')
    # NOTE(review): directories are resolved against the --hypes path even
    # when myhypes points to a different file.
    utils.set_dirs(hypes, tf.app.flags.FLAGS.hypes)

    utils._add_paths_to_sys(hypes)

    train.maybe_download_and_extract(hypes)
    logging.info("Initialize training folder")

    # Initialize the training folder and copy the run's argument files to it.
    train.initialize_training_folder(hypes)
    logging.info("Start training")

    train.do_training(hypes,
                      trainable_scopes=FLAGS.trainable_scopes,
                      exclude_scopes=FLAGS.checkpoint_exclude_scopes,
                      checkpoint_path=FLAGS.checkpoint_path)
# Example #3
# 0
def main(_):
    """Train using the hypes JSON file named by ``--hypes``.

    Exits with status 1 when ``--hypes`` was not supplied. Otherwise loads
    the hypes, prepares GPUs, plugins and directories, sets up the training
    folder and data, and starts training.
    """
    if FLAGS.hypes is None:
        logging.error("No hypes are given.")
        logging.error("Usage: tv-train --hypes hypes.json")
        exit(1)

    hypes_path = tf.app.flags.FLAGS.hypes
    with open(hypes_path, 'r') as hypes_file:
        logging.info("f: %s", hypes_file)
        hypes = json.load(hypes_file)

    # Environment setup: GPU selection, plugins, then output directories.
    for setup_step in (utils.set_gpus_to_use, utils.load_plugins):
        setup_step()
    utils.set_dirs(hypes, hypes_path)

    logging.info("Initialize training folder")
    initialize_training_folder(hypes)
    maybe_download_and_extract(hypes)

    logging.info("Start training")
    do_training(hypes)
# Example #4
# 0
def main(_):
    """Segment the street in one image with a trained KittiSeg model.

    Steps:
    1. Resolve the weights directory. When ``--logdir`` is not given, the
       MultiNet paper weights are downloaded into the runs directory.
    2. Build the inference graph and run it on ``--input_image``.
    3. Log the time of one run and the average time over repeated runs.
    4. Threshold the street confidence.
    5. Show and save a green overlay of the prediction.

    The overlay is saved as ``<--output_image without extension>_green.png``
    when ``--output_image`` is given, otherwise as
    ``prediction_umm_000036.png`` in the working directory.

    Args:
        _: Unused (presumably the argv list passed by ``tf.app.run``;
           confirm against the entry point).
    """
    tv_utils.set_gpus_to_use()

    if FLAGS.input_image is None:
        logging.error("No input_image was given.")
        logging.info(
            "Usage: python demo.py --input_image data/test.png "
            "[--output_image output_image] [--logdir /path/to/weights] "
            "[--gpus GPUs_to_use] ")
        exit(1)

    if FLAGS.logdir is None:
        # Download and use weights from the MultiNet Paper
        if 'TV_DIR_RUNS' in os.environ:
            runs_dir = os.path.join(os.environ['TV_DIR_RUNS'], 'KittiSeg')
        else:
            runs_dir = 'RUNS'
        maybe_download_and_extract(runs_dir)
        logdir = os.path.join(runs_dir, default_run)
    else:
        logging.info("Using weights found in {}".format(FLAGS.logdir))
        logdir = FLAGS.logdir

    # Loading hyperparameters from logdir
    hypes = tv_utils.load_hypes_from_logdir(logdir, base_path='hypes')
    logging.info("Hypes loaded successfully.")

    # Loading tv modules (encoder.py, decoder.py, eval.py) from logdir
    modules = tv_utils.load_modules_from_logdir(logdir)
    logging.info("Modules loaded successfully. Starting to build tf graph.")

    # Create tf graph and build module.
    graph = tf.Graph()
    with graph.as_default():
        # Input: a batch of exactly one 3-channel image of any height/width.
        image_placehold = tf.placeholder(tf.float32, shape=[1, None, None, 3])

        # build Tensorflow graph using the model from logdir
        prediction = core.build_inference_graph(hypes, modules,
                                                image=image_placehold)
        logging.info("Graph build successfully.")
        saver = tf.train.Saver()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # The session is a context manager so it is closed (and its device memory
    # released) once inference is done, even if a step in between raises.
    with tf.Session(graph=graph, config=config) as sess:
        # Load weights from logdir
        core.load_weights(logdir, sess, saver)
        logging.info("Weights loaded successfully.")

        start_time = time.time()
        logging.info('process start')
        input_image = FLAGS.input_image
        logging.info(
            "Starting inference using {} as input".format(input_image))

        # Load the input and resize it to a fixed 150x150 (hard-coded
        # experiment setting).
        image = scp.misc.imread(input_image)
        image = scp.misc.imresize(image, [150, 150, 3])

        # NOTE(review): key is spelled 'reseize_image'; confirm it matches
        # the hypes file.
        if hypes['jitter']['reseize_image']:
            # Resize input only, if specified in hypes
            image_height = hypes['jitter']['image_height']
            image_width = hypes['jitter']['image_width']
            image = scp.misc.imresize(image, size=(image_height, image_width),
                                      interp='cubic')

        # Run KittiSeg model on image (add the batch dimension of 1).
        feed = {image_placehold: np.expand_dims(image, axis=0)}
        softmax = prediction['softmax']

        output = sess.run([softmax], feed_dict=feed)
        logging.info('Finished in {}s\n'.format(time.time() - start_time))
        print('the output.shape =', output[0].shape)

        # Timing: average wall-clock time of repeated runs on the same input.
        # `range` (not the Python-2-only `xrange`) works on both 2 and 3.
        num_timing_runs = 100
        start_time = time.time()
        for _run in range(num_timing_runs):
            output = sess.run([softmax], feed_dict=feed)
        logging.info('Finished {} time and average is in {}s\n'.format(
            num_timing_runs, (time.time() - start_time) / num_timing_runs))

    # Reshape output from flat vector to 2D Image
    shape = image.shape
    print('the image.shape =', shape)
    # Column 1 of the softmax output is the confidence that is thresholded
    # into `street_prediction` below.
    output_image = output[0][:, 1].reshape(shape[0], shape[1])
    print('the first output_image.shape =', output_image.shape)

    # Accept every pixel with confidence > 0.5 as a positive prediction.
    # This creates a `hard` prediction result for class street.
    threshold = 0.5
    street_prediction = output_image > threshold

    # Plot the hard prediction as green overlay
    green_image = tv_utils.fast_overlay(image, street_prediction)

    # Save output image to disk. os.path.splitext (unlike split('.')) keeps
    # paths such as './out.png' or 'dir.v2/out.png' intact.
    if FLAGS.output_image is None:
        green_image_name = 'prediction_umm_000036.png'
    else:
        green_image_name = (os.path.splitext(FLAGS.output_image)[0]
                            + '_green.png')

    scp.misc.imshow(green_image)
    scp.misc.imsave(green_image_name, green_image)