def main(_):
    """Evaluate the pretrained model found in ``RUNS/<FLAGS.RUN>``.

    Checks that the git submodules are available, loads the hypes file
    given by ``FLAGS.hypes``, runs the validation analysis
    (``ana.do_analyze``) and then creates output on the test data
    (``kitti_test.do_inference``) for the selected run directory.

    Parameters
    ----------
    _ : object
        Unused. Presumably the argv list handed over by the app runner;
        TODO confirm.
    """
    import sys

    utils.set_gpus_to_use()

    # These imports only verify that the submodules are checked out; the
    # imported names are not used afterwards.
    try:
        import tensorvision.train  # noqa: F401
        import tensorflow_fcn.utils  # noqa: F401
    except ImportError:
        logging.error("Could not import the submodules.")
        logging.error("Please execute: "
                      "'git submodule update --init --recursive'")
        # sys.exit is always available, unlike the site-provided exit().
        sys.exit(1)

    with open(tf.app.flags.FLAGS.hypes, 'r') as f:
        logging.info("f: %s", f)
        hypes = json.load(f)
    utils.load_plugins()

    runs_dir = 'RUNS'

    utils.set_dirs(hypes, tf.app.flags.FLAGS.hypes)

    utils._add_paths_to_sys(hypes)

    logging.info("Evaluating on Validation data.")
    logdir = os.path.join(runs_dir, FLAGS.RUN)
    ana.do_analyze(logdir)

    logging.info("Creating output on test data.")
    kitti_test.do_inference(logdir)

    logging.info("Analysis for pretrained model complete.")
    logging.info("For evaluating your own models I recommend using:"
                 "`tv-analyze --logdir /path/to/run`.")
    logging.info("tv-analysis has a much cleaner interface.")
# 예제 #2 (Example #2)
# 0
def main(_):
    """Download (if needed) and evaluate the pretrained KittiSeg model.

    Checks that the git submodules are available and loads the hypes file
    given by ``FLAGS.hypes``. It then makes sure the data and the run
    directory are present via the ``maybe_download_and_extract`` helpers.
    Finally it runs the validation analysis and creates output on the test
    data.

    The run directory is ``$TV_DIR_RUNS/KittiSeg`` when the ``TV_DIR_RUNS``
    environment variable is set, otherwise ``RUNS``. The run inside it is
    chosen by ``FLAGS.RUN``.

    Parameters
    ----------
    _ : object
        Unused. Presumably the argv list handed over by the app runner;
        TODO confirm.
    """
    import sys

    utils.set_gpus_to_use()

    # These imports only verify that the submodules are checked out; the
    # imported names are not used afterwards.
    try:
        import tensorvision.train  # noqa: F401
        import tensorflow_fcn.utils  # noqa: F401
    except ImportError:
        logging.error("Could not import the submodules.")
        logging.error("Please execute: "
                      "'git submodule update --init --recursive'")
        # sys.exit is always available, unlike the site-provided exit().
        sys.exit(1)

    with open(tf.app.flags.FLAGS.hypes, 'r') as f:
        logging.info("f: %s", f)
        hypes = json.load(f)
    utils.load_plugins()

    if 'TV_DIR_RUNS' in os.environ:
        runs_dir = os.path.join(os.environ['TV_DIR_RUNS'],
                                'KittiSeg')
    else:
        runs_dir = 'RUNS'

    utils.set_dirs(hypes, tf.app.flags.FLAGS.hypes)

    utils._add_paths_to_sys(hypes)

    train.maybe_download_and_extract(hypes)

    maybe_download_and_extract(runs_dir)
    logging.info("Evaluating on Validation data.")
    logdir = os.path.join(runs_dir, FLAGS.RUN)
    ana.do_analyze(logdir)

    logging.info("Creating output on test data.")
    kitti_test.do_inference(logdir)

    logging.info("Analysis for pretrained model complete.")
    logging.info("For evaluating your own models I recommend using:"
                 "`tv-analyze --logdir /path/to/run`.")
    logging.info("tv-analysis has a much cleaner interface.")
# 예제 #3 (Example #3)
# 0
def main(_):
    """Download (if needed) and evaluate the pretrained KittiBox model.

    Checks that the git submodules are available and loads the hypes file
    given by ``FLAGS.hypes``. It then makes sure the data and the run
    directory are present via the ``maybe_download_and_extract`` helpers.
    Finally it runs the validation analysis with ``base_path='hypes'``.

    The run directory is ``$TV_DIR_RUNS/KittiBox`` when the ``TV_DIR_RUNS``
    environment variable is set, otherwise ``RUNS``. The run inside it is
    chosen by ``FLAGS.RUN``.

    Parameters
    ----------
    _ : object
        Unused. Presumably the argv list handed over by the app runner;
        TODO confirm.
    """
    import sys

    utils.set_gpus_to_use()

    # These imports only verify that the submodules are checked out; the
    # imported names are not used afterwards.
    try:
        import tensorvision.train  # noqa: F401
        import tensorflow_fcn.utils  # noqa: F401
    except ImportError:
        logging.error("Could not import the submodules.")
        logging.error("Please execute: "
                      "'git submodule update --init --recursive'")
        # sys.exit is always available, unlike the site-provided exit().
        sys.exit(1)

    with open(tf.app.flags.FLAGS.hypes, 'r') as f:
        logging.info("f: %s", f)
        hypes = json.load(f)
    utils.load_plugins()

    if 'TV_DIR_RUNS' in os.environ:
        runs_dir = os.path.join(os.environ['TV_DIR_RUNS'], 'KittiBox')
    else:
        runs_dir = 'RUNS'

    utils.set_dirs(hypes, tf.app.flags.FLAGS.hypes)

    utils._add_paths_to_sys(hypes)

    train.maybe_download_and_extract(hypes)

    maybe_download_and_extract(runs_dir)
    logging.info("Evaluating on Validation data.")
    logdir = os.path.join(runs_dir, FLAGS.RUN)
    ana.do_analyze(logdir, base_path='hypes')

    logging.info("Analysis for pretrained model complete.")
    logging.info("For evaluating your own models I recommend using:"
                 "`tv-analyze --logdir /path/to/run`.")
    logging.info("")
    logging.info(
        "Output images can be found in {}/analyse/images.".format(logdir))
# 예제 #4 (Example #4)
# 0
def main(_):
    """Evaluate the model stored in ``FLAGS.logdir``.

    Loads the hypes from ``<FLAGS.logdir>/model_files/hypes.json``, runs
    the validation analysis (``ana.do_analyze``) and then segments the
    test data (``ana.do_inference``).

    Parameters
    ----------
    _ : object
        Unused. Presumably the argv list handed over by the app runner;
        TODO confirm.
    """
    import sys

    utils.set_gpus_to_use()

    # These imports only verify that the submodules are checked out; the
    # imported names are not used afterwards.
    try:
        import tensorvision.train  # noqa: F401
        import tensorflow_fcn.utils  # noqa: F401
    except ImportError:
        logging.error("Could not import the submodules.")
        logging.error("Please execute: "
                      "'git submodule update --init --recursive'")
        # sys.exit is always available, unlike the site-provided exit().
        sys.exit(1)

    hypes_path = os.path.join(FLAGS.logdir, "model_files/hypes.json")

    with open(hypes_path, 'r') as f:
        logging.info("f: %s", f)
        hypes = json.load(f)

    utils.load_plugins()

    # NOTE(review): hypes were loaded from hypes_path, yet FLAGS.hypes is
    # passed as the hypes file name here. Confirm whether hypes_path was
    # intended.
    utils.set_dirs(hypes, FLAGS.hypes)
    utils._add_paths_to_sys(hypes)

    logging.info("Evaluating on Validation data.")
    ana.do_analyze(FLAGS.logdir)

    logging.info("Segmenting and test data. Creating output.")
    ana.do_inference(FLAGS.logdir)

    logging.info("Analysis for pretrained model complete.")
# 예제 #5 (Example #5)
# 0
def do_inference(hypes, modules, logdir):
    """
    Evaluate a trained model on the validation data.

    Builds the inference graph from ``modules`` and loads the weights found
    in ``logdir``. It then evaluates the dataset listed in
    ``hypes['data']['val_file']`` and logs the resulting statistics. The
    statistics also go to ``<eval_out>/eval.log``. Finally it runs
    ``ana.do_analyze`` on ``logdir``.

    With ``FLAGS.inspect`` set, a previously written ``val.json`` is loaded
    and an IPython shell is opened instead; the process then exits.

    Parameters
    ----------
    hypes : dict
        Hyperparameters. ``hypes['dirs']`` is read and, when the
        ``TV_DIR_DATA`` environment variable is set, its ``data_dir`` and
        ``output_dir`` entries are overwritten.
    modules : sequence
        The four model modules ``(data_input, arch, objective, solver)``.
    logdir : string
        Folder with logs and the trained weights.
    """
    import sys

    # Unpacking also enforces that exactly four modules were supplied.
    data_input, arch, objective, solver = modules

    data_dir = hypes['dirs']['data_dir']
    if 'TV_DIR_DATA' in os.environ:
        data_dir = os.environ['TV_DIR_DATA']
        hypes['dirs']['data_dir'] = data_dir
        hypes['dirs']['output_dir'] = logdir

    # Tell TensorFlow that the model will be built into the default Graph.
    with tf.Graph().as_default():

        image_pl, label_pl = _create_input_placeholder()

        image = tf.expand_dims(image_pl, 0)

        # Whitening is on unless hypes['arch']['whitening'] is explicitly
        # falsy.
        if 'whitening' not in hypes['arch'] or \
                hypes['arch']['whitening']:
            image = tf.image.per_image_whitening(image)
            logging.info('Whitening is enabled.')
        else:
            logging.info('Whitening is disabled.')

        # build the graph based on the loaded modules
        softmax = build_inference_graph(hypes,
                                        modules,
                                        image=image,
                                        label=label_pl)

        # prepare the tv session
        sess_coll = core.start_tv_session(hypes)
        sess, saver, summary_op, summary_writer, coord, threads = sess_coll

        _load_weights(logdir, sess, saver)

    _prepare_output_folder(hypes, logdir)

    val_json = os.path.join(hypes['dirs']['eval_out'], 'val.json')

    if FLAGS.inspect:
        if not os.path.exists(val_json):
            logging.error("File does not exist: %s", val_json)
            logging.error("Please run kitti_eval in normal mode first.")
            sys.exit(1)
        else:
            with open(val_json, 'r') as f:
                eval_dict = json.load(f)
                logging.debug(eval_dict)
                from IPython import embed
                embed()
                sys.exit(0)

    logging.info("Doing evaluation with Validation Data")
    val_file = os.path.join(hypes['dirs']['data_dir'],
                            hypes['data']['val_file'])
    eval_dict = eval_dataset(hypes, val_file, True, sess, image_pl, softmax)

    # Dumping eval_dict is disabled. The file is deliberately NOT opened
    # for writing: doing so would truncate val.json to an empty file, and
    # a later --inspect run would then crash in json.load.
    logging.info("Dumping currently not supported")

    logging.info("Succesfully evaluated Dataset. Output is written to %s",
                 val_json)

    logging_file = os.path.join(hypes['dirs']['eval_out'], 'eval.log')
    filewriter = _get_filewrite_handler(logging_file)
    rootlog = logging.getLogger('')
    rootlog.addHandler(filewriter)
    # Always detach and close the file handler, even if logging the
    # statistics or the KITTI evaluation raises; otherwise the handler
    # (and its file) would stay attached to the root logger.
    try:
        logging.info('Statistics on Validation Data.')

        logging.info('MaxF1          : %4.2f', 100 * eval_dict['MaxF'])
        logging.info('BestThresh     : %4.2f',
                     100 * eval_dict['BestThresh'])
        logging.info('Avg Precision  : %4.2f', 100 * eval_dict['AvgPrec'])
        logging.info('')
        # Index of the first threshold >= 0.5.
        ind5 = np.where(eval_dict['thresh'] >= 0.5)[0][0]
        logging.info('Precision @ 0.5: %4.2f',
                     100 * eval_dict['precision'][ind5])
        logging.info('Recall    @ 0.5: %4.2f',
                     100 * eval_dict['recall'][ind5])
        # TPR is the same quantity as recall.
        logging.info('TPR       @ 0.5: %4.2f',
                     100 * eval_dict['recall'][ind5])
        logging.info('TNR       @ 0.5: %4.2f', 100 * eval_dict['TNR'][ind5])

        if FLAGS.kitti_eval:
            do_kitti_eval_with_training_data(hypes, sess, image_pl, softmax)
    finally:
        rootlog.removeHandler(filewriter)
        filewriter.close()

    # Analyze the directory this function was asked to evaluate, not
    # whatever FLAGS.logdir happens to be.
    ana.do_analyze(logdir)