Пример #1
0
def create_session_config(log_device_placement=False,
                          enable_graph_rewriter=False,
                          gpu_mem_fraction=0.95,
                          use_tpu=False,
                          xla_jit_level=tf.OptimizerOptions.OFF,
                          inter_op_parallelism_threads=0,
                          intra_op_parallelism_threads=0):
    """Build the `tf.ConfigProto` to pass to a TensorFlow session.

    Graph options are chosen from exactly one of three branches:

    * `use_tpu` is true: a default-constructed `tf.GraphOptions`.
    * otherwise, `enable_graph_rewriter` is true: rewrite options with the
      layout optimizer switched on.
    * otherwise: optimizer options at level L1, function inlining disabled,
      and `xla_jit_level` as the global JIT level.

    Note that `xla_jit_level` is only consulted in the last branch.

    Args:
        log_device_placement: Forwarded to `tf.ConfigProto`.
        enable_graph_rewriter: Select the rewriter branch described above.
        gpu_mem_fraction: Used as `per_process_gpu_memory_fraction`.
        use_tpu: Select the default-graph-options branch described above.
        xla_jit_level: Global JIT level for the optimizer-options branch.
        inter_op_parallelism_threads: Forwarded to `tf.ConfigProto`.
        intra_op_parallelism_threads: Forwarded to `tf.ConfigProto`.

    Returns:
        A `tf.ConfigProto` with soft placement and isolated session state
        enabled.
    """
    if use_tpu:
        graph_opts = tf.GraphOptions()
    elif enable_graph_rewriter:
        rewriter = rewriter_config_pb2.RewriterConfig()
        rewriter.layout_optimizer = rewriter_config_pb2.RewriterConfig.ON
        graph_opts = tf.GraphOptions(rewrite_options=rewriter)
    else:
        optimizer_opts = tf.OptimizerOptions(
            opt_level=tf.OptimizerOptions.L1,
            do_function_inlining=False,
            global_jit_level=xla_jit_level)
        graph_opts = tf.GraphOptions(optimizer_options=optimizer_opts)

    gpu_opts = tf.GPUOptions(
        per_process_gpu_memory_fraction=gpu_mem_fraction)

    return tf.ConfigProto(
        allow_soft_placement=True,
        graph_options=graph_opts,
        gpu_options=gpu_opts,
        log_device_placement=log_device_placement,
        inter_op_parallelism_threads=inter_op_parallelism_threads,
        intra_op_parallelism_threads=intra_op_parallelism_threads,
        isolate_session_state=True)
Пример #2
0
def get_session(params, isolate_session_state=True):
    """Create a `tf.Session` targeting `params.master`.

    The session config enables soft placement and sets the graph optimizer
    to level L0 with common-subexpression elimination, function inlining
    and constant folding all disabled.

    Args:
        params: Object whose `master` attribute is used as the session
            target.
        isolate_session_state: Forwarded to `tf.ConfigProto`.

    Returns:
        The newly created `tf.Session`.
    """
    optimizer_opts = tf.OptimizerOptions(
        opt_level=tf.OptimizerOptions.L0,
        do_common_subexpression_elimination=False,
        do_function_inlining=False,
        do_constant_folding=False)
    graph_opts = tf.GraphOptions(optimizer_options=optimizer_opts)
    session_config = tf.ConfigProto(
        isolate_session_state=isolate_session_state,
        allow_soft_placement=True,
        graph_options=graph_opts)
    return tf.Session(target=params.master, config=session_config)
Пример #3
0
def main():
    """Run the camera -> object detection -> flight data / MJPEG loop.

    Steps, in order:

    1. Load the serialized `GraphDef` at `CHKPT_PATH` into a fresh graph and
       build the category index from `LABELS_PATH`.
    2. Unless `DEBUG_DISABLE_FLIGHT` is set, create a `Drone` and start
       `flight.flightMain` on its own thread.
    3. Create the HTTP server on port 9090; its thread is started only after
       the first annotated frame has been queued.
    4. Open the camera through the GStreamer pipeline and, while
       `mainThreadRunning` is true, run detection on each frame, pick the
       closest sufficiently-confident detection, hand it to the flight code,
       queue the HUD-annotated frame, and update the smoothed FPS.

    The camera is always released on exit, including on `KeyboardInterrupt`
    (which terminates the process via `sys.exit()`) and on any other
    exception raised inside the loop. The loop also stops if a frame cannot
    be read from the camera.
    """
    # current camera frame
    global frame, annotatedFrame, frameQueue, currentFps, selectedIdx, selectedClassName, objectDistance, boxes, scores, stats
    global currentMode, M_AUTOMANEUVER, M_AUTONAV, M_MANUAL

    print("Loading model")
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(CHKPT_PATH, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')

    label_map = label_map_util.load_labelmap(LABELS_PATH)
    categories = label_map_util.convert_label_map_to_categories(
        label_map, max_num_classes=2, use_display_name=True)
    category_index = label_map_util.create_category_index(categories)

    print("Starting main python module")
    if not DEBUG_DISABLE_FLIGHT:
        flightData = Drone(updateFlightInfo)
        process = Thread(target=flight.flightMain, args=(flightData, ))
        process.start()
    ip = '0.0.0.0'
    server = ThreadedHTTPServer((ip, 9090), CamHandler)
    target = Thread(target=server.serve_forever, args=())
    # The MJPEG server thread is started lazily, once a frame is available.
    streamStarted = False

    # To flip the image, modify the flip_method parameter (0 and 2 are the
    # most common)
    cap = cv2.VideoCapture(gstreamer_pipeline(flip_method=2),
                           cv2.CAP_GSTREAMER)
    fpsSmoothing = 70
    lastUpdate = time.time()
    try:
        if not cap.isOpened():
            print("FATAL: Unable to open camera")
            return

        print("CSI Camera opened")
        graph_options = tf.GraphOptions(
            optimizer_options=tf.OptimizerOptions(
                opt_level=tf.OptimizerOptions.L1, ))
        OptConfig = tf.ConfigProto(graph_options=graph_options)
        with detection_graph.as_default():
            with tf.Session(graph=detection_graph,
                            config=OptConfig) as sess:
                # Input and output tensors of the detection graph.
                image_tensor = detection_graph.get_tensor_by_name(
                    'image_tensor:0')
                # Each box represents a part of the image where a particular
                # object was detected.
                detection_boxes = detection_graph.get_tensor_by_name(
                    'detection_boxes:0')
                # Each score is the confidence for the corresponding box; it
                # is shown on the result image together with the class label.
                detection_scores = detection_graph.get_tensor_by_name(
                    'detection_scores:0')
                detection_classes = detection_graph.get_tensor_by_name(
                    'detection_classes:0')
                num_detections = detection_graph.get_tensor_by_name(
                    'num_detections:0')
                print("TensorFlow session loaded.")
                while mainThreadRunning:
                    ret_val, img = cap.read()
                    if not ret_val or img is None:
                        # A failed grab yields no image; passing it on would
                        # crash inside cv2.cvtColor. Stop cleanly instead and
                        # keep the last good frame in `frame`.
                        print("ERROR: Failed to read frame from camera, "
                              "stopping detection loop")
                        break
                    frame = img
                    # convert OpenCV's BGR to RGB as the model
                    # was trained on RGB images
                    color_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    # resize image to model size of 360x270
                    color_frame = cv2.resize(color_frame, (360, 270),
                                             interpolation=cv2.INTER_CUBIC)
                    image_np_expanded = np.expand_dims(color_frame, axis=0)
                    # Actual detection
                    (rawBoxes, rawScores, rawClasses, num) = sess.run(
                        [
                            detection_boxes, detection_scores,
                            detection_classes, num_detections
                        ],
                        feed_dict={image_tensor: image_np_expanded})

                    # Drop only the batch axis (a single image is fed) so
                    # that box[i] can be used instead of box[0][i]. A plain
                    # np.squeeze would also collapse the detection axis when
                    # the model returns exactly one detection.
                    # NOTE(review): assumes every output has a leading batch
                    # dimension of size 1 -- confirm for the deployed model.
                    boxes = np.squeeze(rawBoxes, axis=0)
                    classes = np.squeeze(rawClasses, axis=0)
                    scores = np.squeeze(rawScores, axis=0)

                    # Draw boxes using TF library, should be off during
                    # competition
                    if useBoxVisualization:
                        vis_util.visualize_boxes_and_labels_on_image_array(
                            frame,
                            boxes,
                            classes.astype(np.int32),
                            scores,
                            category_index,
                            use_normalized_coordinates=True,
                            line_thickness=4,
                            min_score_thresh=MIN_CONFIDENCE)

                    # Now that we have the detected BBoxes, it's time to
                    # determine our current obstacle. First, gather stats
                    # about the bounding boxes.
                    stats = []
                    # Per the original comment this is 15ft; any object
                    # farther than that is treated as a misidentification.
                    lowestDistance = 15

                    if DEBUG_DUMP_DETECTIONS:
                        print("Boxes // Classes // Scores")
                        print(boxes)
                        print(classes)
                        print(scores)
                    # Reset selections
                    selectedIdx = None
                    for j, (box, cls, score) in enumerate(
                            zip(boxes, classes, scores)):
                        if score < MIN_CONFIDENCE:
                            # Skip calculations on this box if it does not
                            # meet the confidence threshold; the placeholder
                            # keeps stats[j] aligned with boxes[j].
                            stats.append(0)
                            continue
                        boxStats = getBoxStats(box, cls)
                        stats.append(boxStats)
                        if boxStats['distance'] < lowestDistance:
                            selectedIdx = j
                            selectedClassName = classToString(cls)
                            objectDistance = boxStats['distance']
                            lowestDistance = objectDistance

                    if not DEBUG_DISABLE_FLIGHT:
                        if selectedIdx is not None:
                            flightData.upData(stats[selectedIdx],
                                              selectedClassName)
                        else:
                            flightData.upData(None, "None")

                    # add the HUD to the current image
                    annotatedFrame = applyHud()
                    # NOTE(review): if frameQueue is a bounded queue, put()
                    # blocks while it is full -- confirm a consumer is
                    # always draining it.
                    frameQueue.put(annotatedFrame.copy())
                    if not streamStarted:
                        target.start()
                        print("Starting MJPEG stream")
                        streamStarted = True

                    # FPS smoothing: move currentFps a fraction of the way
                    # toward the instantaneous rate on every frame.
                    now = time.time()
                    frameTime = now - lastUpdate
                    lastUpdate = now
                    # Two clock reads can be equal; skip the update rather
                    # than divide by zero.
                    if frameTime > 0:
                        frameFps = 1 / frameTime
                        currentFps += (frameFps - currentFps) / fpsSmoothing

    except KeyboardInterrupt:
        sys.exit()
    finally:
        # Release the camera on every exit path, not only after a normal
        # loop exit.
        cap.release()