Example #1
def get_output(interpreter, score_threshold):
    """ Returns list of detected objects.

    Args:
        interpreter
        score_threshold

    Returns: bounding_box, class_id, score
    """
    # Get all output details
    boxes = get_output_tensor(interpreter, 0)
    class_ids = get_output_tensor(interpreter, 1)
    scores = get_output_tensor(interpreter, 2)
    count = int(get_output_tensor(interpreter, 3))

    results = []
    for i in range(count):
        if scores[i] >= score_threshold:
            result = {
                "bounding_box": boxes[i],
                "class_id": class_ids[i],
                "score": scores[i],
            }
            results.append(result)
    return results
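
All of the get_output variants in these examples call a get_output_tensor helper that is not shown in the excerpt. A minimal sketch, assuming the standard tf.lite.Interpreter API (the batch-dimension squeeze follows the TensorFlow Lite example code):

import numpy as np

def get_output_tensor(interpreter, index):
    """Returns the output tensor at the given index, with the batch dim removed."""
    output_details = interpreter.get_output_details()[index]
    return np.squeeze(interpreter.get_tensor(output_details["index"]))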
Example #2
def get_output(interpreter, top_k=1, score_threshold=0.0):
    """ Returns list of detected objects.
    """
    scores = get_output_tensor(interpreter, 0)
    classes = [
        Class(i, scores[i]) for i in np.argpartition(scores, -top_k)[-top_k:]
        if scores[i] >= score_threshold
    ]
    return sorted(classes, key=operator.itemgetter(1), reverse=True)
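
This classification variant additionally assumes numpy (np), the operator module, and a Class record. A minimal set of supporting definitions, with the namedtuple shape borrowed from the TensorFlow Lite classification examples (an assumption here, inferred from the sort key):

import collections
import operator

import numpy as np

# Assumed (id, score) record; operator.itemgetter(1) then sorts by score.
Class = collections.namedtuple("Class", ["id", "score"])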
def get_output(interpreter, score_threshold, is_keypoints=False):
    """Returns list of detected objects.

    Args:
        interpreter
        score_threshold
        is_keypoints

    Returns: bounding_box, class_id, score
    """
    # Get all output details
    boxes = get_output_tensor(interpreter, 0)
    class_ids = get_output_tensor(interpreter, 1)
    scores = get_output_tensor(interpreter, 2)
    count = int(get_output_tensor(interpreter, 3))
    keypoints = None
    keypoints_scores = None
    if is_keypoints:
        keypoints = get_output_tensor(interpreter, 4)
        keypoints_scores = get_output_tensor(interpreter, 5)

    results = []
    for i in range(count):
        if scores[i] >= score_threshold:
            result = {
                "bounding_box": boxes[i],
                "class_id": class_ids[i],
                "box_score": scores[i],
            }
            if is_keypoints:
                keypoint_result = {
                    "keypoints": keypoints[i],
                    "keypoints_score": keypoints_scores[i],
                }
                result.update(keypoint_result)
            results.append(result)
    return results
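
A hypothetical call for a detection model that also emits keypoints in output tensors 4 and 5, as indexed above (the threshold value is only illustrative):

interpreter.invoke()
objects = get_output(interpreter, score_threshold=0.5, is_keypoints=True)
for obj in objects:
    print(obj["class_id"], obj["box_score"], obj["keypoints"])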
def get_output(interpreter):
    """ Returns list of detected objects.
    """
    # Get all output details
    seg_map = get_output_tensor(interpreter, 0)
    return seg_map
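
main() below also depends on make_interpreter, set_input_tensor, and create_mask, none of which appear in this excerpt. Sketches under the assumption that they follow the usual TensorFlow Lite patterns, with names and signatures inferred from the call sites (on a Raspberry Pi, tflite_runtime.interpreter.Interpreter accepts the same constructor arguments):

import numpy as np
import tensorflow as tf

def make_interpreter(model_file, num_threads):
    # Assumed helper: build a TF-Lite interpreter with the requested thread count.
    return tf.lite.Interpreter(model_path=model_file, num_threads=num_threads)

def set_input_tensor(interpreter, image):
    # Assumed helper: copy the already-batched image into input tensor 0.
    input_details = interpreter.get_input_details()[0]
    interpreter.set_tensor(input_details["index"], image)

def create_mask(predictions):
    # Assumed helper, as in the TensorFlow segmentation tutorial:
    # a per-pixel argmax over the class channel yields the label map.
    return np.argmax(predictions, axis=-1)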
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model",
                        help="File path of Tflite model.",
                        required=True)
    parser.add_argument("--threshold",
                        help="threshold to filter results.",
                        default=0.5,
                        type=float)
    parser.add_argument("--width",
                        help="Resolution width.",
                        default=640,
                        type=int)
    parser.add_argument("--height",
                        help="Resolution height.",
                        default=480,
                        type=int)
    parser.add_argument("--thread", help="Num threads.", default=2, type=int)
    args = parser.parse_args()

    # Initialize window.
    cv2.namedWindow(
        WINDOW_NAME,
        cv2.WINDOW_GUI_NORMAL | cv2.WINDOW_AUTOSIZE | cv2.WINDOW_KEEPRATIO)
    cv2.moveWindow(WINDOW_NAME, 100, 200)

    # Initialize TF-Lite interpreter.
    interpreter = make_interpreter(args.model, args.thread)
    interpreter.allocate_tensors()
    _, height, width, channel = interpreter.get_input_details()[0]["shape"]
    print("Interpreter: ", height, width, channel)

    model_name = os.path.splitext(os.path.basename(args.model))[0]

    elapsed_list = []

    resolution_width = args.width
    resolution_height = args.height
    with picamera.PiCamera() as camera:

        camera.resolution = (resolution_width, resolution_height)
        camera.framerate = 30
        rawCapture = PiRGBArray(camera)

        # allow the camera to warmup
        time.sleep(0.1)

        try:
            for frame in camera.capture_continuous(rawCapture,
                                                   format="rgb",
                                                   use_video_port=True):
                rawCapture.truncate(0)

                start = time.perf_counter()

                image = frame.array
                # The frame arrives as RGB; swap channels to BGR for OpenCV.
                im = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
                resize_im = cv2.resize(im, (width, height))
                # Normalize to [0, 1] float32 to match the model's input spec.
                input_im = resize_im.astype(np.float32) / 255.0

                # Run inference.
                set_input_tensor(interpreter, input_im[np.newaxis, :, :])
                interpreter.invoke()
                predictions = get_output_tensor(interpreter, 0)

                pred_mask = create_mask(predictions)
                # Scale the small class ids up to visible gray levels.
                pred_mask = np.array(pred_mask, dtype="uint8") * 127
                pred_mask = cv2.resize(pred_mask,
                                       (resolution_width, resolution_height))

                inference_time = (time.perf_counter() - start) * 1000

                # Track a moving average of the inference time.
                elapsed_list.append(inference_time)
                avg_text = ""
                if len(elapsed_list) > 100:
                    elapsed_list.pop(0)
                    avg_elapsed_ms = np.mean(elapsed_list)
                    avg_text = " AGV: {0:.2f}ms".format(avg_elapsed_ms)

                # Overlay the model name and timing on the frame.
                fps_text = "Inference: {0:.2f}ms".format(inference_time)
                display_text = model_name + " " + fps_text + avg_text
                visual.draw_caption(im, (10, 30), display_text)

                # Show the frame and the predicted mask side by side.
                pred_mask = cv2.cvtColor(pred_mask, cv2.COLOR_GRAY2BGR)
                display = cv2.hconcat([im, pred_mask])
                cv2.imshow(WINDOW_NAME, display)
                if cv2.waitKey(10) & 0xFF == ord("q"):
                    break

        finally:
            camera.stop_preview()

    # When everything is done, close the window.
    cv2.destroyAllWindows()
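
The excerpt ends without an entry-point guard; presumably the script closes with the usual:

if __name__ == "__main__":
    main()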