Example #1
def classify_image(model_file, image_file, image_quantization=None):
    """Runs image classification and returns result with the highest score.

  Args:
    model_file: string, model file name.
    image_file: string, image file name.
    image_quantization: (scale: float, zero_point: float), assumed image
      quantization parameters.

  Returns:
    Classification result with the highest score as (index, score) tuple.
  """
    interpreter = make_interpreter(test_data_path(model_file))
    interpreter.allocate_tensors()
    image = test_image(image_file, common.input_size(interpreter))

    input_type = common.input_details(interpreter, 'dtype')
    if np.issubdtype(input_type, np.floating):
        # This preprocessing is specific to MobileNet V1 with floating point input.
        image = (input_type(image) - 127.5) / 127.5

    if np.issubdtype(input_type, np.integer) and image_quantization:
        image = rescale_image(
            image, image_quantization,
            common.input_details(interpreter, 'quantization'), input_type)

    common.set_input(interpreter, image)
    interpreter.invoke()
    return classify.get_classes(interpreter)[0]
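The snippet calls a rescale_image helper that is not shown. A minimal sketch, assuming it simply dequantizes the pixels with the assumed (scale, zero_point) pair, requantizes them with the model's own input quantization, and clips to the integer range of the input type:

import numpy as np

def rescale_image(image, image_quantization, input_quantization, input_type):
    # Hypothetical helper, not part of the original snippet.
    image_scale, image_zero_point = image_quantization
    input_scale, input_zero_point = input_quantization
    # Dequantize with the assumed parameters, then requantize for the input tensor.
    real_values = image_scale * (np.asarray(image).astype(np.float64) - image_zero_point)
    requantized = real_values / input_scale + input_zero_point
    # Clip to the valid range of the integer input type before casting.
    info = np.iinfo(input_type)
    return np.clip(np.round(requantized), info.min, info.max).astype(input_type)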
Example #2
def predict():
    data = {"success": False}
    if flask.request.method == "POST":
        if flask.request.files.get("image"):
            image_file = flask.request.files["image"]
            image = Image.open(image_file).convert('RGB').resize(
                HOLDER['size'], Image.LANCZOS)  # LANCZOS is the former ANTIALIAS filter
            params = common.input_details(HOLDER['interpreter'],
                                          'quantization_parameters')
            scale = params['scales']
            zero_point = params['zero_points']
            mean = 128.0
            std = 128.0
            if abs(scale * std - 1) < 1e-5 and abs(mean - zero_point) < 1e-5:
                # Input data does not require preprocessing.
                common.set_input(HOLDER['interpreter'], image)
            else:
                # Input data requires preprocessing
                normalized_input = (np.asarray(image) -
                                    mean) / (std * scale) + zero_point
                np.clip(normalized_input, 0, 255, out=normalized_input)
                common.set_input(HOLDER['interpreter'],
                                 normalized_input.astype(np.uint8))

            start = time.perf_counter()
            HOLDER['interpreter'].invoke()
            inference_time = time.perf_counter() - start
            classes = classify.get_classes(HOLDER['interpreter'],
                                           HOLDER['top_k'], 0.0)

            if classes:
                data["success"] = True
                data["inference-time"] = '%.2f ms' % (inference_time * 1000)
                preds = []
                for c in classes:
                    preds.append({
                        "score": float(c.score),
                        "label": HOLDER['labels'].get(c.id, c.id)
                    })
                data["predictions"] = preds
    return flask.jsonify(data)
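The handler reads a multipart file field named image. A minimal client sketch for exercising it (the /predict route and port 5000 are assumptions, since the Flask routing is not shown in this snippet):

import requests

with open('parrot.jpg', 'rb') as f:
    resp = requests.post('http://localhost:5000/predict', files={'image': f})
print(resp.json())  # {"success": true, "inference-time": "...", "predictions": [...]}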
Example #3
def init(args):
    global HOLDER
    HOLDER['model'] = args.model

    # Only build the labels path when a labels file was actually given,
    # otherwise the concatenation would fail on a missing argument.
    labels = read_label_file(args.models_directory + args.labels) if args.labels else {}

    model_file = args.models_directory + args.model
    interpreter = make_interpreter(model_file)
    interpreter.allocate_tensors()

    print("\n Loaded engine with model : {}".format(model_file))

    # Model must be uint8 quantized
    if common.input_details(interpreter, 'dtype') != np.uint8:
        raise ValueError('Only support uint8 input type.')
    size = common.input_size(interpreter)

    HOLDER['labels'] = labels
    HOLDER['interpreter'] = interpreter
    HOLDER['size'] = size
    HOLDER['top_k'] = args.top_k
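A hypothetical sketch of how init() might be wired up; the flag names are assumptions inferred from the attributes the function reads (models_directory, model, labels, top_k), and HOLDER is the module-level dict shared with the request handler above:

import argparse

HOLDER = {}  # module-level state shared with predict()

parser = argparse.ArgumentParser()
parser.add_argument('--models_directory', default='models/')  # concatenated, so keep the trailing slash
parser.add_argument('--model', default='mobilenet_v2_1.0_224_quant_edgetpu.tflite')
parser.add_argument('--labels', default='imagenet_labels.txt')
parser.add_argument('--top_k', type=int, default=3)
init(parser.parse_args())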
Example #4
def main():
    global mot_tracker
    default_model_dir = '../models'
    default_model = 'mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite'
    default_labels = 'coco_labels.txt'
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help='.tflite model path',
                        default=os.path.join(default_model_dir, default_model))
    parser.add_argument('--labels', help='label file path',
                        default=os.path.join(default_model_dir, default_labels))
    parser.add_argument('--top_k', type=int, default=3,
                        help='number of categories with highest score to display')
    parser.add_argument('--camera_idx', type=int, default=0,
                        help='Index of which video source to use.')
    parser.add_argument('--threshold', type=float, default=0.1,
                        help='classifier score threshold')
    parser.add_argument('--tracker', help='Name of the Object Tracker To be used.',
                        default=None,
                        choices=[None, 'sort'])
    parser.add_argument('--videosrc', help='Directly connected (dev) or Networked (net) video source. ', choices=['dev','net','file'],
                        default='dev')
    parser.add_argument('--display', help='Is a display attached',
                        default='False',
                        choices=['True', 'False'])
    parser.add_argument('--netsrc', help="Networked video source, example format: rtsp://192.168.1.43/mpeg4/media.amp",)
    parser.add_argument('--filesrc', help="Video file source. The videos subdirectory gets mapped into the Docker container, so place your files there.",)
    parser.add_argument('--modelInt8', help="Model expects input tensors to be Int8, not UInt8", default='False', choices=['True', 'False'])
    
    args = parser.parse_args()
    
    trackerName = args.tracker
    # Check for the object tracker.
    objectOfTracker = None
    if trackerName is not None:
        if trackerName == 'mediapipe':
            if detectCoralDevBoard():
                objectOfTracker = ObjectTracker('mediapipe')
            else:
                print("Tracker MediaPipe is only available on the Dev Board. Keeping the tracker as None")
                trackerName = None
        else:
            objectOfTracker = ObjectTracker(trackerName)

    if trackerName is not None and objectOfTracker:
        mot_tracker = objectOfTracker.trackerObject.mot_tracker
    else:
        mot_tracker = None
    print('Loading {} with {} labels.'.format(args.model, args.labels))
    interpreter = make_interpreter(args.model)
    interpreter.allocate_tensors()
    labels = read_label_file(args.labels)
    inference_size = input_size(interpreter)
    model_int8 = args.modelInt8 == 'True'

    if args.videosrc == 'dev':
        cap = cv2.VideoCapture(args.camera_idx)
    elif args.videosrc == 'file':
        cap = cv2.VideoCapture(args.filesrc)
    else:
        if args.netsrc is None:
            print("--videosrc was set to net but --netsrc was not specified")
            sys.exit()
        cap = cv2.VideoCapture(args.netsrc)
        
    cap.set(cv2.CAP_PROP_BUFFERSIZE, 0)
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret: 
            if args.videosrc=='file':
                cap = cv2.VideoCapture(args.filesrc)
                continue  
            else:
                break
        cv2_im = frame

        cv2_im_rgb = cv2.cvtColor(cv2_im, cv2.COLOR_BGR2RGB)
        cv2_im_rgb = cv2.resize(cv2_im_rgb, inference_size)

        if model_int8:
            # Signed-input models: cast to the model's dtype and scale to roughly [-1, 1].
            input_type = common.input_details(interpreter, 'dtype')
            img = (input_type(cv2_im_rgb) - 127.5) / 128.0
            run_inference(interpreter, img.flatten())
        else:
            run_inference(interpreter, cv2_im_rgb.tobytes())

        objs = get_objects(interpreter, args.threshold)[:args.top_k]
        height, width, channels = cv2_im.shape
        scale_x, scale_y = width / inference_size[0], height / inference_size[1]
        detections = []
        for obj in objs:
            bbox = obj.bbox.scale(scale_x, scale_y)
            detections.append(
                [bbox.xmin, bbox.ymin, bbox.xmax, bbox.ymax, obj.score, obj.id])
        # Convert to a numpy array for the tracker.
        detections = np.array(detections)
        trdata = []
        trackerFlag = False
        if detections.any():
            if mot_tracker is not None:
                trdata = mot_tracker.update(detections)
                trackerFlag = True

        cv2_im = append_objs_to_img(cv2_im, detections, labels, trdata, trackerFlag)
        
        if args.display == 'True':
            cv2.imshow('frame', cv2_im)
        
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()
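append_objs_to_img is called above but not defined in the snippet. A minimal sketch, assuming detection rows of [xmin, ymin, xmax, ymax, score, class_id] (as built in the loop above) and SORT-style tracker rows of [xmin, ymin, xmax, ymax, track_id]:

import cv2

def append_objs_to_img(cv2_im, detections, labels, trdata, trackerFlag):
    # Hypothetical sketch: draw tracker boxes when tracking is active,
    # otherwise draw the raw detections with their labels and scores.
    if trackerFlag:
        for x0, y0, x1, y1, track_id in trdata:
            cv2_im = cv2.rectangle(cv2_im, (int(x0), int(y0)), (int(x1), int(y1)),
                                   (0, 255, 0), 2)
            cv2_im = cv2.putText(cv2_im, 'id: {}'.format(int(track_id)),
                                 (int(x0), int(y0) - 5),
                                 cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
    else:
        for x0, y0, x1, y1, score, class_id in detections:
            label = '{:.0f}% {}'.format(100 * score, labels.get(int(class_id), int(class_id)))
            cv2_im = cv2.rectangle(cv2_im, (int(x0), int(y0)), (int(x1), int(y1)),
                                   (0, 255, 0), 2)
            cv2_im = cv2.putText(cv2_im, label, (int(x0), int(y0) - 5),
                                 cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)
    return cv2_im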
Example #5
def main():
    global mot_tracker
    global mqtt_bridge
    global mqtt_topic

    camera_width = 1280
    camera_height = 720

    default_model_dir = '../models'
    default_model = 'mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite'
    default_labels = 'coco_labels.txt'
    parser = argparse.ArgumentParser()
    parser.add_argument('--model',
                        help='.tflite model path',
                        default=os.path.join(default_model_dir, default_model))
    parser.add_argument('--labels',
                        help='label file path',
                        default=os.path.join(default_model_dir,
                                             default_labels))
    parser.add_argument(
        '--top_k',
        type=int,
        default=3,
        help='number of categories with highest score to display')
    parser.add_argument('--camera_idx',
                        type=int,
                        help='Index of which video source to use. ',
                        default=0)
    parser.add_argument('--threshold',
                        type=float,
                        default=0.1,
                        help='classifier score threshold')
    parser.add_argument('--tracker',
                        help='Name of the Object Tracker To be used.',
                        default=None,
                        choices=[None, 'sort'])
    parser.add_argument(
        '--videosrc',
        help='Directly connected (dev) or Networked (net) video source. ',
        choices=['dev', 'net', 'file'],
        default='dev')
    parser.add_argument('--display',
                        help='Is a display attached',
                        default='False',
                        choices=['True', 'False'])
    parser.add_argument(
        '--netsrc',
        help=
        "Networked video source, example format: rtsp://192.168.1.43/mpeg4/media.amp",
    )
    parser.add_argument(
        '--filesrc',
        help=
        "Video file source. The videos subdirectory gets mapped into the Docker container, so place your files there.",
    )
    parser.add_argument(
        '--modelInt8',
        help="Model expects input tensors to be Int8, not UInt8",
        default='False',
        choices=['True', 'False'])
    parser.add_argument('--mqtt-host',
                        help="MQTT broker hostname",
                        default='127.0.0.1')
    parser.add_argument('--mqtt-port',
                        type=int,
                        help="MQTT broker port number (default 1883)",
                        default=1883)
    parser.add_argument('--mqtt-topic',
                        dest='mqtt_topic',
                        help="MQTT Object Tracking topic",
                        default="skyscan/object/json")

    args = parser.parse_args()

    trackerName = args.tracker
    # Check for the object tracker.
    objectOfTracker = None
    if trackerName is not None:
        if trackerName == 'mediapipe':
            if detectCoralDevBoard():
                objectOfTracker = ObjectTracker('mediapipe')
            else:
                print(
                    "Tracker MediaPipe is only available on the Dev Board. Keeping the tracker as None"
                )
                trackerName = None
        else:
            objectOfTracker = ObjectTracker(trackerName)

    if trackerName is not None and objectOfTracker:
        mot_tracker = objectOfTracker.trackerObject.mot_tracker
    else:
        mot_tracker = None
    mqtt_topic = args.mqtt_topic
    mqtt_bridge = mqtt_wrapper.bridge(host=args.mqtt_host,
                                      port=args.mqtt_port,
                                      client_id="skyscan-object-tracker-%s" %
                                      (ID))
    mqtt_bridge.publish("skyscan/registration",
                        "skyscan-adsb-mqtt-" + ID + " Registration", 0, False)

    print('Loading {} with {} labels.'.format(args.model, args.labels))
    interpreter = make_interpreter(args.model)
    interpreter.allocate_tensors()
    labels = read_label_file(args.labels)
    inference_size = input_size(interpreter)
    model_int8 = args.modelInt8 == 'True'

    if args.videosrc == 'dev':
        cap = cv2.VideoCapture(args.camera_idx)
    elif args.videosrc == 'file':
        cap = cv2.VideoCapture(args.filesrc)
    else:
        if args.netsrc is None:
            print("--videosrc was set to net but --netsrc was not specified")
            sys.exit()
        cap = cv2.VideoCapture(args.netsrc)

    cap.set(cv2.CAP_PROP_BUFFERSIZE, 0)
    timeHeartbeat = 0
    while cap.isOpened():
        if timeHeartbeat < time.mktime(time.gmtime()):
            timeHeartbeat = time.mktime(time.gmtime()) + 10
            mqtt_bridge.publish("skyscan/heartbeat",
                                "skyscan-object-tracker-" + ID + " Heartbeat",
                                0, False)
        start_time = time.monotonic()
        ret, frame = cap.read()
        if not ret:
            if args.videosrc == 'file':
                cap = cv2.VideoCapture(args.filesrc)
                continue
            else:
                break
        cv2_im = frame

        cv2_im_rgb = cv2.cvtColor(cv2_im, cv2.COLOR_BGR2RGB)
        cv2_im_rgb = cv2.resize(cv2_im_rgb, inference_size)

        if model_int8:
            # Signed-input models: cast to the model's dtype and scale to roughly [-1, 1].
            input_type = common.input_details(interpreter, 'dtype')
            img = (input_type(cv2_im_rgb) - 127.5) / 128.0
            run_inference(interpreter, img.flatten())
        else:
            run_inference(interpreter, cv2_im_rgb.tobytes())

        objs = get_objects(interpreter, args.threshold)[:args.top_k]
        height, width, channels = cv2_im.shape
        scale_x, scale_y = width / inference_size[0], height / inference_size[1]
        detections = []
        for obj in objs:
            bbox = obj.bbox.scale(scale_x, scale_y)
            detections.append(
                [bbox.xmin, bbox.ymin, bbox.xmax, bbox.ymax, obj.score, obj.id])
        # Convert to a numpy array for the tracker.
        detections = np.array(detections)
        trdata = []
        trackerFlag = False
        if detections.any():
            if mot_tracker is not None:
                trdata = mot_tracker.update(detections)
                trackerFlag = True

        cv2_im = append_objs_to_img(cv2_im, detections, labels, trdata,
                                    trackerFlag)
        follow_x, follow_y = object_to_follow(detections, labels, trdata,
                                              trackerFlag)
        if args.display == 'True':
            cv2.imshow('frame', cv2_im)

        if follow_x is not None:
            follow_x = int(follow_x * (camera_height / height))
            follow_y = int(follow_y * (camera_width / width))
            coordinates = motionControl(follow_x, follow_y)
            follow = {"x": coordinates[0], "y": coordinates[1]}
            follow_json = json.dumps(follow)
            end_time = time.monotonic()
            print("x: {} y:{} new_x: {} new_y: {} Inference: {:.2f} ms".format(
                follow_x, follow_y, coordinates[0], coordinates[1],
                (end_time - start_time) * 1000))
            mqtt_bridge.publish(mqtt_topic, follow_json, 0, False)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()
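The tracked coordinates are published as JSON on the skyscan/object/json topic (the --mqtt-topic default above). A minimal subscriber sketch, assuming the paho-mqtt 1.x client API rather than the mqtt_wrapper used by the example (paho-mqtt 2.x additionally requires a CallbackAPIVersion argument to Client()):

import json
import paho.mqtt.client as mqtt

def on_message(client, userdata, msg):
    # Payload published by the loop above: {"x": ..., "y": ...}
    print(msg.topic, json.loads(msg.payload))

client = mqtt.Client()
client.on_message = on_message
client.connect('127.0.0.1', 1883)  # matches the --mqtt-host / --mqtt-port defaults
client.subscribe('skyscan/object/json')
client.loop_forever()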
Example #6
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-m',
                        '--model',
                        required=True,
                        help='File path of .tflite file.')
    parser.add_argument('-i',
                        '--input',
                        required=True,
                        help='Image to be classified.')
    parser.add_argument('-l', '--labels', help='File path of labels file.')
    parser.add_argument('-k',
                        '--top_k',
                        type=int,
                        default=1,
                        help='Max number of classification results')
    parser.add_argument('-t',
                        '--threshold',
                        type=float,
                        default=0.0,
                        help='Classification score threshold')
    parser.add_argument('-c',
                        '--count',
                        type=int,
                        default=5,
                        help='Number of times to run inference')
    parser.add_argument('-a',
                        '--input_mean',
                        type=float,
                        default=128.0,
                        help='Mean value for input normalization')
    parser.add_argument('-s',
                        '--input_std',
                        type=float,
                        default=128.0,
                        help='STD value for input normalization')
    args = parser.parse_args()

    labels = read_label_file(args.labels) if args.labels else {}

    interpreter = make_interpreter(*args.model.split('@'))
    interpreter.allocate_tensors()

    # Model must be uint8 quantized
    if common.input_details(interpreter, 'dtype') != np.uint8:
        raise ValueError('Only support uint8 input type.')

    size = common.input_size(interpreter)
    image = Image.open(args.input).convert('RGB').resize(size, Image.LANCZOS)  # LANCZOS is the former ANTIALIAS filter

    # Image data must go through two transforms before running inference:
    # 1. normalization: f = (input - mean) / std
    # 2. quantization: q = f / scale + zero_point
    # The following code combines the two steps as such:
    # q = (input - mean) / (std * scale) + zero_point
    # However, if std * scale equals 1, and mean - zero_point equals 0, the input
    # does not need any preprocessing (but in practice, even if the results are
    # very close to 1 and 0, it is probably okay to skip preprocessing for better
    # efficiency; we use 1e-5 below instead of absolute zero).
    params = common.input_details(interpreter, 'quantization_parameters')
    scale = params['scales']
    zero_point = params['zero_points']
    mean = args.input_mean
    std = args.input_std
    if abs(scale * std - 1) < 1e-5 and abs(mean - zero_point) < 1e-5:
        # Input data does not require preprocessing.
        common.set_input(interpreter, image)
    else:
        # Input data requires preprocessing
        normalized_input = (np.asarray(image) - mean) / (std *
                                                         scale) + zero_point
        np.clip(normalized_input, 0, 255, out=normalized_input)
        common.set_input(interpreter, normalized_input.astype(np.uint8))

    # Run inference
    print('----INFERENCE TIME----')
    print('Note: The first inference on Edge TPU is slow because it includes',
          'loading the model into Edge TPU memory.')
    for _ in range(args.count):
        start = time.perf_counter()
        interpreter.invoke()
        inference_time = time.perf_counter() - start
        classes = classify.get_classes(interpreter, args.top_k, args.threshold)
        print('%.1fms' % (inference_time * 1000))

    print('-------RESULTS--------')
    for c in classes:
        print('%s: %.5f' % (labels.get(c.id, c.id), c.score))
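As a quick check of the skip-preprocessing condition described in the comment block above: for a typical uint8-quantized MobileNet input with scale = 1/128 and zero_point = 128 (illustrative values; the real ones come from the model), std * scale is exactly 1 and mean - zero_point is 0, so the raw image can be fed in directly:

scale, zero_point = 1 / 128.0, 128.0   # illustrative; read from the model in practice
mean, std = 128.0, 128.0               # the --input_mean / --input_std defaults
print(abs(scale * std - 1) < 1e-5 and abs(mean - zero_point) < 1e-5)  # True -> skip preprocessing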