# Example no. 1
def main():
    """Run image retrieval over ROIs detected in a video stream.

    Parses CLI arguments, builds the retrieval model and the ROI
    detector, then for every detected ROI computes an embedding, ranks
    the gallery by distance, and visualizes the top matches. Press Esc
    (key code 27) to stop; if --ground_truth was given, retrieval
    metrics are reported at the end.
    """
    log.basicConfig(format='[ %(levelname)s ] %(message)s',
                    level=log.INFO, stream=sys.stdout)
    args = build_argparser().parse_args()

    img_retrieval = ImageRetrieval(args.model, args.device, args.gallery,
                                   INPUT_SIZE, args.cpu_extension)

    frames = RoiDetectorOnVideo(args.i)

    # Running timings, averaged for the on-screen overlay.
    compute_embeddings_times = []
    search_in_gallery_times = []

    # Rank of the ground-truth class per frame (only with --ground_truth).
    positions = []

    for image, view_frame in frames:
        position = None
        sorted_indexes = []

        if image is not None:
            # Focus on the central region of the ROI before embedding.
            image = central_crop(image, divide_by=5, shift=1)

            elapsed, probe_embedding = time_elapsed(
                img_retrieval.compute_embedding, image)
            compute_embeddings_times.append(elapsed)

            elapsed, (sorted_indexes, distances) = time_elapsed(
                img_retrieval.search_in_gallery, probe_embedding)
            search_in_gallery_times.append(elapsed)

            sorted_classes = [img_retrieval.gallery_classes[i]
                              for i in sorted_indexes]

            if args.ground_truth is not None:
                # Rank at which the expected class shows up in the results.
                position = sorted_classes.index(
                    img_retrieval.text_label_to_class_id[args.ground_truth])
                positions.append(position)
                # Fixed typo: "postion" -> "position" (matches the other
                # variant of this demo).
                log.info("ROI detected, found: %d, position of target: %d",
                         sorted_classes[0], position)
            else:
                log.info("ROI detected, found: %s", sorted_classes[0])

        key = visualize(view_frame, position,
                        [img_retrieval.impaths[i] for i in sorted_indexes],
                        distances[sorted_indexes] if position is not None else None,
                        img_retrieval.input_size, np.mean(compute_embeddings_times),
                        np.mean(search_in_gallery_times), imshow_delay=3)

        if key == 27:  # Esc terminates playback.
            break

    if positions:
        compute_metrics(positions)
def main():
    """Entry point: ROI-based image retrieval on a video/camera input.

    Sets up capture, retrieval model, utilization monitors and an
    optional output video writer; processes frames until the stream
    ends or Esc (key code 27) is pressed, then prints monitor means
    and, when ground truth was supplied, retrieval metrics.
    """
    log.basicConfig(format='[ %(levelname)s ] %(message)s',
                    level=log.INFO,
                    stream=sys.stdout)
    args = build_argparser().parse_args()

    img_retrieval = ImageRetrieval(args.model, args.device, args.gallery,
                                   INPUT_SIZE, args.cpu_extension)

    cap = open_images_capture(args.input, args.loop)
    if cap.get_type() not in ('VIDEO', 'CAMERA'):
        raise RuntimeError(
            "The input should be a video file or a numeric camera ID")
    frames = RoiDetectorOnVideo(cap)

    # Per-stage timings, averaged for the overlay.
    embedding_times = []
    gallery_times = []

    # Rank of the expected class per frame (only with --ground_truth).
    target_ranks = []

    frame_count = 0
    presenter = monitors.Presenter(args.utilization_monitors, 0)
    writer = cv2.VideoWriter()

    for roi, view_frame in frames:
        rank = None
        ranked = []

        if roi is not None:
            # Narrow the ROI to its central region before embedding.
            roi = central_crop(roi, divide_by=5, shift=1)

            took, probe = time_elapsed(img_retrieval.compute_embedding, roi)
            embedding_times.append(took)

            took, (ranked, distances) = time_elapsed(
                img_retrieval.search_in_gallery, probe)
            gallery_times.append(took)

            classes = [img_retrieval.gallery_classes[i] for i in ranked]

            if args.ground_truth is None:
                log.info("ROI detected, found: %s", classes[0])
            else:
                rank = classes.index(
                    img_retrieval.text_label_to_class_id[args.ground_truth])
                target_ranks.append(rank)
                log.info("ROI detected, found: %d, position of target: %d",
                         classes[0], rank)

        gallery_paths = [img_retrieval.impaths[i] for i in ranked]
        shown_distances = distances[ranked] if rank is not None else None
        frame, key = visualize(view_frame,
                               rank,
                               gallery_paths,
                               shown_distances,
                               img_retrieval.input_size,
                               np.mean(embedding_times),
                               np.mean(gallery_times),
                               imshow_delay=3,
                               presenter=presenter,
                               no_show=args.no_show)

        # Lazily open the writer once the first output frame size is known.
        if frame_count == 0 and args.output:
            opened = writer.open(args.output,
                                 cv2.VideoWriter_fourcc(*'MJPG'),
                                 cap.fps(),
                                 (frame.shape[1], frame.shape[0]))
            if not opened:
                raise RuntimeError("Can't open video writer")
        frame_count += 1
        within_limit = (args.output_limit <= 0
                        or frame_count <= args.output_limit)
        if writer.isOpened() and within_limit:
            writer.write(frame)

        if key == 27:  # Esc terminates playback.
            break
    print(presenter.reportMeans())

    if target_ranks:
        compute_metrics(target_ranks)