Example #1
def infer_on_stream(args, client):
    """
    Initialize the inference network, stream video to network,
    and output stats and video.
    
    - Load the model.
    - Capture the input stream from a camera, video file, or image.
    - Run async inference per frame.
    - Calculate stats and send the frame and stats to the MQTT and FFmpeg servers.
    
    Parameters:
        args: Command line arguments parsed by `build_argparser()`.
        client: Connected MQTT client.
    
    Returns:
        None
    """
    # Initialise the class
    infer_network = Network()
    # Set Probability threshold for detections
    prob_threshold = args.prob_threshold

    ### TODO: Load the model through `infer_network` ###
    infer_network.exec_network = infer_network.load_model(
        args.model, args.device, args.cpu_extension)
    # extract information about model input layer
    (b, c, input_height, input_width) = infer_network.get_input_shape()

    ### TODO: Handle the input stream ###
    # extension of the input file
    input_extension = os.path.splitext(args.input)[1].lower()
    supported_vid_exts = ['.mp4', '.mpeg', '.avi', '.mkv']
    supported_img_exts = [".bmp",".dib", ".jpeg", ".jp2", ".jpg", ".jpe",\
        ".png", ".pbm", ".pgm", ".ppm", ".sr", ".ras", ".tiff", ".tif"]
    single_image_mode = False
    # if input is camera
    if args.input.upper() == 'CAM':
        capture = cv2.VideoCapture(0)

    # if input is video
    elif input_extension in supported_vid_exts:
        capture = cv2.VideoCapture(args.input)

    # if input is image
    elif input_extension in supported_img_exts:
        single_image_mode = True
        capture = cv2.VideoCapture(args.input)
        capture.open(args.input)
    else:
        sys.exit("FATAL ERROR : The format of your input file is not supported" \
                     "\nsupported extensions are : " + ", ".join(supported_exts))
    prev_count = 0
    total_persons = 0
    ### TODO: Loop until stream is over ###
    while (capture.isOpened()):
        ### TODO: Read from the video capture ###
        ret, frame = capture.read()
        if not ret:
            break
        ### TODO: Pre-process the image as needed ###
        image = preprocessing(frame, input_width, input_height)
        ### TODO: Start asynchronous inference for specified request ###
        start_time = time.time()
        # run inference
        infer_network.exec_net(image)
        ### TODO: Wait for the result ###
        if infer_network.wait() == 0:
            infer_time = time.time() - start_time
            ### TODO: Get the results of the inference request ###
            outputs = infer_network.get_output()[0][0]
            ### Take model output and extract number of detections with confidence exceeding threshold
            ### and draw bounding boxes around detections
            out_image, current_count = apply_threshold(outputs, frame,
                                                       prob_threshold)

            # show inference time on image
            cv2.putText(out_image, "inference time: {:.2f} ms".format(infer_time * 1000),
                        (30, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 255, 0), 1)

            ### TODO: Extract any desired stats from the results ###
            # when any person exit
            if current_count < prev_count:
                ### Topic "person/duration": key of "duration" ###
                # send duration to mqtt server client
                client.publish("person/duration",
                               json.dumps({"duration": time.time() - p_start}))

            # when new person enters
            if current_count > prev_count:
                total_persons += current_count - prev_count
                p_start = time.time()

            prev_count = current_count

            ### TODO: Calculate and send relevant information on ###
            ### current_count, total_count and duration to the MQTT server ###
            ### Topic "person": keys of "count" and "total" ###
            client.publish(
                "person",
                json.dumps({
                    "count": current_count,
                    "total": total_persons
                }))
        ### TODO: Send the frame to the FFMPEG server ###
        sys.stdout.buffer.write(out_image)
        sys.stdout.buffer.flush()
        ### TODO: Write an output image if `single_image_mode` ###
        if single_image_mode:
            cv2.imwrite("output_frame.png", out_image)
    # release resources
    capture.release()
    cv2.destroyAllWindows()
    client.disconnect()
    del infer_network
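
Example #1 depends on two project helpers, `preprocessing()` and `apply_threshold()`, that are not shown above. The following is a minimal sketch of what they might look like, assuming an SSD-style detector whose output rows have the layout `[image_id, label, conf, x_min, y_min, x_max, y_max]`; the names, signatures, and drawing details are inferred from how Example #1 calls them, not taken from the original project.

import cv2


def preprocessing(frame, input_width, input_height):
    """Resize a BGR frame to the network's input size and reorder it to NCHW."""
    image = cv2.resize(frame, (input_width, input_height))
    image = image.transpose((2, 0, 1))      # HWC -> CHW
    image = image.reshape(1, *image.shape)  # add the batch dimension
    return image


def apply_threshold(outputs, frame, prob_threshold):
    """Draw a box for each detection above the threshold; return the image and count."""
    height, width = frame.shape[:2]
    current_count = 0
    for detection in outputs:  # one row per detection
        confidence = detection[2]
        if confidence > prob_threshold:
            x_min = int(detection[3] * width)
            y_min = int(detection[4] * height)
            x_max = int(detection[5] * width)
            y_max = int(detection[6] * height)
            cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), (0, 255, 0), 1)
            current_count += 1
    return frame, current_count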
Example #2
def infer_on_stream(args, client):
    """
    Initialize the inference network, stream video to network,
    and output stats and video.
    :param args: Command line arguments parsed by `build_argparser()`
    :param client: MQTT client
    :return: None
    """
    global streaming_enabled
    # Initialise the class
    infer_network = Network()

    total_unique_persons = []
    use_reidentification = False
    # Set Probability threshold for detections
    if args.prob_threshold is not None:
        prob_threshold = args.prob_threshold
    else:
        prob_threshold = 0.2

    cur_request_id = 0
    last_detection_time = None
    duration = 0

    start = None

    single_image_mode = False
    show_info = False
    if args.show_info:
        show_info = args.show_info
    message = None
    if args.message:
        message = args.message

    if args.input == 'CAM':
        input_stream = 0
    elif args.input.endswith(('.jpg', '.bmp', '.png')):
        single_image_mode = True
        input_stream = args.input
    # Checks for video file
    else:
        input_stream = args.input
        assert os.path.isfile(args.input), "Specified input file doesn't exist"

    ### TODO: Load the model through `infer_network` ###
    n, c, h, w = infer_network.load_model(args.model, args.device, 1, 1,
                                          cur_request_id,
                                          args.cpu_extension)[1]

    # Initialize the class for re-identification
    networkReIdentification = None
    identification_input_shape = None

    if args.reident_model:
        networkReIdentification = Network()
        networkReIdentification.load_model(args.reident_model, args.device, 1,
                                           1, cur_request_id,
                                           args.cpu_extension)
        identification_input_shape = networkReIdentification.get_input_shape()
        use_reidentification = True

    ### TODO: Handle the input stream ###
    if not single_image_mode:
        cap = cv2.VideoCapture(input_stream)
        if input_stream:
            cap.open(args.input)
        if not cap.isOpened():
            log.error("ERROR! Unable to open video source")

        detection_frame_count = 0
        total_frame_count = 0
        previous_detection_time = None
        last_person_counts = []
        average_person_count = 0
        detection_time = None

        total_seconds_elapsed_for_detection = 0

        # Parameters for duration
        max_len = 40
        track_threshold = 0.2
        track = deque(maxlen=max_len)

        ### TODO: Loop until stream is over ###
        while cap.isOpened():
            ### TODO: Read from the video capture ###
            flag, frame = cap.read()
            if not flag:
                break

            ### TODO: Pre-process the image as needed ###
            image = preprocessing(frame, h, w)

            ### TODO: Start asynchronous inference for specified request ###
            inf_start = time.time()

            infer_network.exec_network(cur_request_id, image)
            ### TODO: Wait for the result ###
            output_img = frame
            if infer_network.wait(cur_request_id) == 0:
                ### TODO: Get the results of the inference request ###
                det_time = time.time() - inf_start
                result = infer_network.get_output(cur_request_id)

                ### TODO: Extract any desired stats from the results ###

                image_h, image_w, _ = frame.shape
                num_detections = 0
                for box in result[0][0]:
                    label = box[1]
                    conf = box[2]

                    if label == 1:
                        if (conf > prob_threshold):
                            x_min = int(box[3] * image_w)
                            y_min = int(box[4] * image_h)
                            x_max = int(box[5] * image_w)
                            y_max = int(box[6] * image_h)
                            dist = (y_max - y_min) / (y_min + y_max)
                            color = (0, dist * 255, 255 - dist * 255)
                            if use_reidentification:
                                try:
                                    if conf > 0.85:
                                        crop_person = frame[y_min:y_max,
                                                            x_min:x_max]

                                        total_unique_persons = reidentification(
                                            cur_request_id,
                                            networkReIdentification,
                                            crop_person,
                                            identification_input_shape,
                                            total_unique_persons, conf)
                                except Exception:
                                    # ignore re-identification failures (e.g. empty crops)
                                    pass

                            cv2.rectangle(frame, (x_min, y_min),
                                          (x_max, y_max), color, int(1))
                            num_detections += 1
                            last_detection_time = datetime.now()
                            if start is None:
                                start = time.time()
                    else:
                        label_box_pos = None
                    if last_detection_time is not None:
                        second_diff = (datetime.now() -
                                       last_detection_time).total_seconds()
                        if second_diff >= 1.5:
                            if start is not None and num_detections == 0:
                                elapsed = time.time() - start
                                client.publish(
                                    "person/duration",
                                    json.dumps({"duration": elapsed}))
                                start = None
                                last_detection_time = None

                person_counts = num_detections

                overlay = output_img.copy()
                if show_info:
                    cv2.putText(overlay, message, (10, 40), FONT, 1,
                                (250, 250, 250), 2, cv2.LINE_AA)
                    cv2.putText(overlay,
                                'Person[s] found: {}'.format(person_counts),
                                (10, overlay.shape[0] - 40), FONT, 1,
                                (255, 255, 255), 1, cv2.LINE_AA)
                    cv2.putText(
                        overlay,
                        str(datetime.now().strftime(
                            "%A, %d. %B %Y %I:%M:%S %p")),
                        (10, overlay.shape[0] - 20), FONT, 1, (250, 250, 250),
                        1, cv2.LINE_AA)
                    cv2.addWeighted(overlay, ALPHA, output_img, 1 - ALPHA, 0,
                                    output_img)

            # keep a rolling window of the most recent person counts
            if len(last_person_counts) > 10:
                last_person_counts = last_person_counts[-10:]
            last_person_counts.append(person_counts)

            average_person_count = int(
                sum(last_person_counts) / len(last_person_counts))

            client.publish(
                "person",
                json.dumps({
                    "count": str(person_counts),
                    "total": len(total_unique_persons)
                }))

            ### TODO: Send the frame to the FFMPEG server ###
            if streaming_enabled:
                sys.stdout.buffer.write(output_img)
                sys.stdout.flush()
        if cap:
            cap.release()
            cv2.destroyAllWindows()
            client.disconnect()
            infer_network.dispose()

    ### TODO: Write an output image if `single_image_mode` ###
    elif single_image_mode:
        frame = cv2.imread(input_stream)
        image = preprocessing(frame, h, w)
        infer_network.exec_network(0, image)
        if infer_network.wait(0) == 0:
            result = infer_network.get_output(0)
            output_img, person_counts = get_draw_boxes_on_image(
                result, frame, prob_threshold, True)
            cv2.imwrite('output_image.jpg', output_img)
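
All three examples drive inference through a `Network` wrapper class whose implementation is not shown. The sketch below illustrates one way such a wrapper could be written against OpenVINO's (pre-2022) `openvino.inference_engine` Python API, following the method names used in Example #1 (`load_model`, `get_input_shape`, `exec_net`, `wait`, `get_output`); Examples #2 and #3 call variants that take explicit request ids, so treat this as an assumption-laden illustration rather than the original class.

import os

from openvino.inference_engine import IECore


class Network:
    """Minimal OpenVINO wrapper exposing the calls used in Example #1."""

    def __init__(self):
        self.ie = IECore()
        self.network = None
        self.exec_network = None
        self.input_blob = None
        self.output_blob = None

    def load_model(self, model_xml, device="CPU", cpu_extension=None):
        model_bin = os.path.splitext(model_xml)[0] + ".bin"
        if cpu_extension and "CPU" in device:
            self.ie.add_extension(cpu_extension, "CPU")
        self.network = self.ie.read_network(model=model_xml, weights=model_bin)
        self.input_blob = next(iter(self.network.input_info))
        self.output_blob = next(iter(self.network.outputs))
        self.exec_network = self.ie.load_network(self.network, device)
        return self.exec_network

    def get_input_shape(self):
        # NCHW shape of the model's input layer
        return self.network.input_info[self.input_blob].input_data.shape

    def exec_net(self, image, request_id=0):
        # start an asynchronous inference request
        self.exec_network.start_async(request_id=request_id,
                                      inputs={self.input_blob: image})

    def wait(self, request_id=0):
        # block until the request completes; returns 0 on success
        return self.exec_network.requests[request_id].wait(-1)

    def get_output(self, request_id=0):
        return self.exec_network.requests[request_id].outputs[self.output_blob]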
Example #3
def infer_on_stream(args, client):
    """
    Initialize the inference network, stream video to network,
    and output stats and video.

    :param args: Command line arguments parsed by `build_argparser()`
    :param client: MQTT client
    :return: None
    """
    # Initialise the class
    infer_network = Network()
    # Set Probability threshold for detections
    #prob_threshold = args.prob_threshold
    single_image = False
    ### TODO: Load the model through `infer_network` ###
    infer_network.load_model(args.model, args.device, CPU_EXTENSION)
    ### TODO: Handle the input stream ###
    net_input_shape = infer_network.get_input_shape()

    #Check for CAM, image or video
    if args.input == 'CAM':
        input_stream = 0
    elif args.input.endswith('.jpg') or args.input.endswith('.bmp'):
        single_image = True
        input_stream = args.input
    else:
        input_stream = args.input
        if not os.path.isfile(args.input):
            log.error("Specified input file doesn't exist")
            sys.exit(1)

    cap = cv2.VideoCapture(input_stream)
    if input_stream:
        cap.open(args.input)

    if not cap.isOpened():
        log.error("Unable to open source")
    width = int(cap.get(3))
    height = int(cap.get(4))
    # 0x00000021 is a FOURCC value that selects an MP4-compatible codec on some OpenCV builds
    out = cv2.VideoWriter('out.mp4', 0x00000021, 10, (width, height))
    global incident_flag, quantity, timesnap, timer, ticks, pt
    pt = args.prob_threshold
    incident_flag = False
    quantity = 0
    total = 0
    timesnap = 0
    timer = False
    ticks = 0
    curr_count = 0
    doneCounter = False
    start_time = 0
    ### TODO: Loop until stream is over ###
    while cap.isOpened():
        ### TODO: Read from the video capture ###
        timesnap += 1
        flag, original_frame = cap.read()
        if not flag:
            break
        if cv2.waitKey(25) & 0xFF == ord('q'):
            break
        ### TODO: Pre-process the image as needed ###
        frame = cv2.resize(original_frame,
                           (net_input_shape[3], net_input_shape[2]))
        frame = frame.transpose((2, 0, 1))
        frame = frame.reshape(1, *frame.shape)
        ### TODO: Start asynchronous inference for specified request ###
        inf_start = time.time()
        infer_network.exec_network(frame)
        ### TODO: Wait for the result ###
        if infer_network.wait() == 0:
            det_time = time.time() - inf_start
            ### TODO: Get the results of the inference request ###
            result = infer_network.get_output()
            ### TODO: Extract any desired stats from the results ###
            out_frame = draw_boundingBox(result, original_frame, height, width)

            inf_time_message = "Inference time: {:.3f}ms".format(det_time *
                                                                 1000)
            cv2.putText(out_frame, inf_time_message, (15, 15),
                        cv2.FONT_HERSHEY_COMPLEX, 0.5, (200, 10, 10), 1)
            ### TODO: Calculate and send relevant information on ###
            out.write(out_frame)
            cv2.imshow('frame', out_frame)
            curr_count = detect_person(result, curr_count)
            if incident_flag and not doneCounter:
                start_time = time.time()
                total += 1
                print("Total: {}".format(total))
                doneCounter = True
                json.dumps({"total": total})
                #client.publish("person",json.dumps({"total":total}))
            if not incident_flag and doneCounter and total >= 1:
                doneCounter = False
                duration = int(time.time() - start_time)
        print("Count: {}".format(curr_count))
        # Publish messages to the MQTT server
        #client.publish("person/duration",
        #json.dumps({"duration": duration}))
        # client.publish("person",json.dumps({"count":curr_count}))
        ### current_count, total_count and duration to the MQTT server ###
        ### Topic "person": keys of "count" and "total" ###
        ### Topic "person/duration": key of "duration" ###

        #sys.stdout.buffer.write(out_frame)
        #sys.stdout.flush()
        ### TODO: Send the frame to the FFMPEG server ###
        ### TODO: Write an output image if `single_image_mode` ###
        if single_image:
            cv2.imwrite('out_image.jpg', out_frame)
    out.release()
    cap.release()
    cv2.destroyAllWindows()
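
None of the examples shows how `args` and `client` are created, or how the frames written to stdout reach FFmpeg. A minimal driver might look like the sketch below, assuming a local broker reached through paho-mqtt (the host, port, and keepalive values are placeholders, not taken from the examples). In a typical deployment the script's stdout is piped into an `ffmpeg` process reading raw BGR frames (`-f rawvideo -pix_fmt bgr24 -s <width>x<height> -i -`).

import paho.mqtt.client as mqtt

# Assumed broker settings; adjust to your deployment.
MQTT_HOST = "localhost"
MQTT_PORT = 3001
MQTT_KEEPALIVE_INTERVAL = 60


def connect_mqtt():
    """Create and connect the MQTT client passed into infer_on_stream()."""
    client = mqtt.Client()
    client.connect(MQTT_HOST, MQTT_PORT, MQTT_KEEPALIVE_INTERVAL)
    return client


def main():
    """Parse command-line arguments, connect to MQTT, and run the pipeline."""
    args = build_argparser().parse_args()
    client = connect_mqtt()
    infer_on_stream(args, client)


if __name__ == '__main__':
    main()

Note that Examples #1 and #2 disconnect the MQTT client inside `infer_on_stream()`, so the driver does not disconnect it again.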