예제 #1
0
    def test_perform_inference_without_mask_output(self):
        """Raw retinanet (bbox-only) inference should expose per-category box arrays."""
        from sahi.model import MmdetDetectionModel

        # init model
        download_mmdet_retinanet_model()

        mmdet_detection_model = MmdetDetectionModel(
            model_path=MmdetTestConstants.MMDET_RETINANET_MODEL_PATH,
            config_path=MmdetTestConstants.MMDET_RETINANET_CONFIG_PATH,
            confidence_threshold=0.5,
            device=None,
            category_remapping=None,
            load_at_init=True,
        )

        # prepare image
        image_path = "tests/data/small-vehicles1.jpeg"
        image = read_image(image_path)

        # perform inference
        mmdet_detection_model.perform_inference(image)
        original_predictions = mmdet_detection_model.original_predictions

        # without mask output, original_predictions is a per-category list of
        # box arrays; index 2 is the "car" category
        boxes = original_predictions

        # find box of first car detection with conf greater than 0.5
        # (debug print removed to keep test output clean)
        for box in boxes[2]:
            if len(box) == 5 and box[4] > 0.5:
                break

        # compare
        self.assertEqual(box[:4].astype("int").tolist(), [448, 309, 495, 341])
        self.assertEqual(len(boxes), 80)
예제 #2
0
    def test_get_prediction(self):
        """get_prediction on a full image should return the expected detections."""
        from sahi.model import MmdetDetectionModel
        from sahi.predict import get_prediction

        from tests.test_utils import (
            download_mmdet_cascade_mask_rcnn_model,
            mmdet_cascade_mask_rcnn_config_path,
            mmdet_cascade_mask_rcnn_model_path,
        )

        # init model
        download_mmdet_cascade_mask_rcnn_model()

        mmdet_detection_model = MmdetDetectionModel(
            model_path=mmdet_cascade_mask_rcnn_model_path,
            config_path=mmdet_cascade_mask_rcnn_config_path,
            prediction_score_threshold=0.3,
            device=None,
            category_remapping=None,
        )
        mmdet_detection_model.load_model()

        # prepare image
        image_path = "tests/data/small-vehicles1.jpeg"
        image = read_image(image_path)

        # get full sized prediction
        prediction_result = get_prediction(
            image=image,
            detection_model=mmdet_detection_model,
            shift_amount=[0, 0],
            full_image_size=None,
            merger=None,
            matcher=None,
        )
        object_prediction_list = prediction_result["object_prediction_list"]

        # compare total count and per-category counts
        self.assertEqual(len(object_prediction_list), 23)
        num_person = sum(
            1 for object_prediction in object_prediction_list
            if object_prediction.category.name == "person")
        self.assertEqual(num_person, 0)
        num_truck = sum(
            1 for object_prediction in object_prediction_list
            if object_prediction.category.name == "truck")
        self.assertEqual(num_truck, 3)
        num_car = sum(
            1 for object_prediction in object_prediction_list
            if object_prediction.category.name == "car")
        self.assertEqual(num_car, 20)
예제 #3
0
    def test_convert_original_predictions(self):
        """convert_original_predictions should yield category-mapped ObjectPredictions."""
        from sahi.model import MmdetDetectionModel

        # init model
        download_mmdet_cascade_mask_rcnn_model()

        mmdet_detection_model = MmdetDetectionModel(
            model_path=mmdet_cascade_mask_rcnn_model_path,
            config_path=mmdet_cascade_mask_rcnn_config_path,
            prediction_score_threshold=0.5,
            device=None,
            category_remapping=None,
        )
        mmdet_detection_model.load_model()

        # prepare image
        image_path = "tests/data/small-vehicles1.jpeg"
        image = read_image(image_path)

        # perform inference
        mmdet_detection_model.perform_inference(image)

        # convert predictions to ObjectPrediction list
        mmdet_detection_model.convert_original_predictions()
        object_prediction_list_w_category_mapping = (
            mmdet_detection_model.object_prediction_list)

        # compare
        self.assertEqual(len(object_prediction_list_w_category_mapping), 53)
        self.assertEqual(
            object_prediction_list_w_category_mapping[0].category.id, 0)
        self.assertEqual(
            object_prediction_list_w_category_mapping[0].category.name,
            "person")
        self.assertEqual(
            object_prediction_list_w_category_mapping[0].bbox.to_coco_bbox(),
            [337, 124, 8, 14],
        )
        self.assertEqual(
            object_prediction_list_w_category_mapping[1].category.id, 2)
        self.assertEqual(
            object_prediction_list_w_category_mapping[1].category.name, "car")
        self.assertEqual(
            object_prediction_list_w_category_mapping[1].bbox.to_coco_bbox(),
            [657, 204, 13, 10],
        )
        self.assertEqual(
            object_prediction_list_w_category_mapping[5].category.id, 2)
        self.assertEqual(
            object_prediction_list_w_category_mapping[5].category.name, "car")
        # NOTE(review): category is checked for index 5 but bbox for index 2 —
        # looks like an index mismatch; confirm which prediction is intended.
        self.assertEqual(
            object_prediction_list_w_category_mapping[2].bbox.to_coco_bbox(),
            [760, 232, 20, 15],
        )
예제 #4
0
    def test_create_original_predictions_from_object_prediction_list_with_mask_output(
        self, ):
        """Round-tripping mask predictions through ObjectPrediction must be lossless."""
        from sahi.model import MmdetDetectionModel

        # init model
        download_mmdet_cascade_mask_rcnn_model()

        mmdet_detection_model = MmdetDetectionModel(
            model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH,
            config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH,
            confidence_threshold=0.5,
            device=None,
            category_remapping=None,
            load_at_init=True,
        )

        # prepare image
        image_path = "tests/data/small-vehicles1.jpeg"
        image = read_image(image_path)

        # perform inference
        mmdet_detection_model.perform_inference(image)
        original_predictions_1 = mmdet_detection_model.original_predictions

        # convert predictions to ObjectPrediction list
        mmdet_detection_model.convert_original_predictions()
        object_prediction_list = mmdet_detection_model.object_prediction_list

        original_predictions_2 = mmdet_detection_model._create_original_predictions_from_object_prediction_list(
            object_prediction_list)

        # compare
        self.assertEqual(len(original_predictions_1),
                         len(original_predictions_2))  # 2
        self.assertEqual(len(original_predictions_1[0]),
                         len(original_predictions_2[0]))  # 80
        self.assertEqual(len(original_predictions_1[0][2]),
                         len(original_predictions_2[0][2]))  # 25
        self.assertEqual(type(original_predictions_1[0]),
                         type(original_predictions_2[0]))  # list
        self.assertEqual(original_predictions_1[0][2].dtype,
                         original_predictions_2[0][2].dtype)  # float32
        self.assertEqual(original_predictions_1[0][0][0].dtype,
                         original_predictions_2[0][0][0].dtype)  # float32
        self.assertEqual(original_predictions_1[1][0][0].dtype,
                         original_predictions_2[1][0][0].dtype)  # bool
        self.assertEqual(len(original_predictions_1[0][0][0]),
                         len(original_predictions_2[0][0][0]))  # 5
        # fixed: previously compared original_predictions_1 against itself
        self.assertEqual(len(original_predictions_1[0][1]),
                         len(original_predictions_2[0][1]))  # 0
        self.assertEqual(original_predictions_1[0][1].shape,
                         original_predictions_2[0][1].shape)  # (0, 5)
예제 #5
0
    def test_create_original_predictions_from_object_prediction_list(self):
        """Round-tripping predictions through ObjectPrediction must be lossless."""
        from sahi.model import MmdetDetectionModel

        # init model
        download_mmdet_cascade_mask_rcnn_model()

        mmdet_detection_model = MmdetDetectionModel(
            model_path=mmdet_cascade_mask_rcnn_model_path,
            config_path=mmdet_cascade_mask_rcnn_config_path,
            prediction_score_threshold=0.5,
            device=None,
            category_remapping=None,
        )
        mmdet_detection_model.load_model()

        # prepare image
        image_path = "tests/data/small-vehicles1.jpeg"
        image = read_image(image_path)

        # perform inference
        mmdet_detection_model.perform_inference(image)
        original_predictions_1 = mmdet_detection_model.original_predictions

        # convert predictions to ObjectPrediction list
        mmdet_detection_model.convert_original_predictions()
        object_prediction_list = mmdet_detection_model.object_prediction_list

        original_predictions_2 = mmdet_detection_model._create_original_predictions_from_object_prediction_list(
            object_prediction_list)

        # compare
        self.assertEqual(len(original_predictions_1),
                         len(original_predictions_2))  # 2
        self.assertEqual(len(original_predictions_1[0]),
                         len(original_predictions_2[0]))  # 80
        self.assertEqual(len(original_predictions_1[0][2]),
                         len(original_predictions_2[0][2]))  # 25
        self.assertEqual(type(original_predictions_1[0]),
                         type(original_predictions_2[0]))  # list
        self.assertEqual(original_predictions_1[0][2].dtype,
                         original_predictions_2[0][2].dtype)  # float32
        self.assertEqual(original_predictions_1[0][0][0].dtype,
                         original_predictions_2[0][0][0].dtype)  # float32
        self.assertEqual(original_predictions_1[1][0][0].dtype,
                         original_predictions_2[1][0][0].dtype)  # bool
        self.assertEqual(len(original_predictions_1[0][0][0]),
                         len(original_predictions_2[0][0][0]))  # 5
        # fixed: previously compared original_predictions_1 against itself
        self.assertEqual(len(original_predictions_1[0][1]),
                         len(original_predictions_2[0][1]))  # 0
        self.assertEqual(original_predictions_1[0][1].shape,
                         original_predictions_2[0][1].shape)  # (0, 5)
예제 #6
0
    def test_prediction_input(self):
        """PredictionInput should create one shift amount per input image."""
        from sahi.prediction import PredictionInput

        # load a sample image
        image_path = "tests/data/small-vehicles1.jpeg"
        image = read_image(image_path)

        # build the prediction input from a single-image list
        prediction_input = PredictionInput(image_list=[image], )

        # shift amounts must be paired one-to-one with images
        self.assertEqual(
            len(prediction_input.shift_amount_list),
            len(prediction_input.image_list),
        )
예제 #7
0
    def test_get_prediction_mmdet(self):
        """get_prediction with mmdet cascade mask rcnn should detect only cars."""
        from sahi.model import MmdetDetectionModel
        from sahi.predict import get_prediction
        from sahi.utils.mmdet import MmdetTestConstants, download_mmdet_cascade_mask_rcnn_model

        # init model
        download_mmdet_cascade_mask_rcnn_model()

        mmdet_detection_model = MmdetDetectionModel(
            model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH,
            config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH,
            confidence_threshold=0.3,
            device=None,
            category_remapping=None,
        )
        mmdet_detection_model.load_model()

        # prepare image
        image_path = "tests/data/small-vehicles1.jpeg"
        image = read_image(image_path)

        # get full sized prediction
        prediction_result = get_prediction(
            image=image,
            detection_model=mmdet_detection_model,
            shift_amount=[0, 0],
            full_shape=None,
        )
        object_prediction_list = prediction_result.object_prediction_list

        # compare total and per-category detection counts
        self.assertEqual(len(object_prediction_list), 19)
        num_person = sum(
            1 for object_prediction in object_prediction_list
            if object_prediction.category.name == "person")
        self.assertEqual(num_person, 0)
        num_truck = sum(
            1 for object_prediction in object_prediction_list
            if object_prediction.category.name == "truck")
        self.assertEqual(num_truck, 0)
        num_car = sum(
            1 for object_prediction in object_prediction_list
            if object_prediction.category.name == "car")
        self.assertEqual(num_car, 19)
예제 #8
0
    def test_convert_original_predictions_with_mask_output(self):
        """Converted cascade mask rcnn predictions should match expected categories/boxes."""
        from sahi.model import MmdetDetectionModel

        # init model
        download_mmdet_cascade_mask_rcnn_model()

        mmdet_detection_model = MmdetDetectionModel(
            model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH,
            config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH,
            confidence_threshold=0.5,
            device=None,
            category_remapping=None,
            load_at_init=True,
        )

        # prepare image
        image_path = "tests/data/small-vehicles1.jpeg"
        image = read_image(image_path)

        # perform inference
        mmdet_detection_model.perform_inference(image)

        # convert predictions to ObjectPrediction list
        mmdet_detection_model.convert_original_predictions()
        object_prediction_list = mmdet_detection_model.object_prediction_list

        # compare count and spot-check individual detections
        self.assertEqual(len(object_prediction_list), 53)
        self.assertEqual(object_prediction_list[0].category.id, 0)
        self.assertEqual(object_prediction_list[0].category.name, "person")
        self.assertEqual(
            object_prediction_list[0].bbox.to_coco_bbox(),
            [337, 124, 8, 14],
        )
        self.assertEqual(object_prediction_list[1].category.id, 2)
        self.assertEqual(object_prediction_list[1].category.name, "car")
        self.assertEqual(
            object_prediction_list[1].bbox.to_coco_bbox(),
            [657, 204, 13, 10],
        )
        self.assertEqual(object_prediction_list[5].category.id, 2)
        self.assertEqual(object_prediction_list[5].category.name, "car")
        # NOTE(review): category is checked for index 5 but bbox for index 2 —
        # looks like an index mismatch; confirm which prediction is intended.
        self.assertEqual(
            object_prediction_list[2].bbox.to_coco_bbox(),
            [760, 232, 20, 15],
        )
예제 #9
0
    def test_get_prediction_yolov5(self):
        """get_prediction with yolov5s6 should detect only cars."""
        from sahi.model import Yolov5DetectionModel
        from sahi.predict import get_prediction
        from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5s6_model

        # init model
        download_yolov5s6_model()

        yolov5_detection_model = Yolov5DetectionModel(
            model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH,
            confidence_threshold=0.3,
            device=None,
            category_remapping=None,
            load_at_init=False,
        )
        yolov5_detection_model.load_model()

        # prepare image
        image_path = "tests/data/small-vehicles1.jpeg"
        image = read_image(image_path)

        # get full sized prediction
        prediction_result = get_prediction(
            image=image, detection_model=yolov5_detection_model, shift_amount=[0, 0], full_shape=None, postprocess=None
        )
        object_prediction_list = prediction_result.object_prediction_list

        # compare total and per-category detection counts
        self.assertEqual(len(object_prediction_list), 12)
        num_person = sum(
            1 for object_prediction in object_prediction_list
            if object_prediction.category.name == "person")
        self.assertEqual(num_person, 0)
        num_truck = sum(
            1 for object_prediction in object_prediction_list
            if object_prediction.category.name == "truck")
        self.assertEqual(num_truck, 0)
        num_car = sum(
            1 for object_prediction in object_prediction_list
            if object_prediction.category.name == "car")
        self.assertEqual(num_car, 12)
예제 #10
0
    def test_convert_original_predictions(self):
        """Converted yolov5 predictions should match the expected categories and boxes."""
        from sahi.model import Yolov5DetectionModel

        # download weights and build the model, loading it immediately
        download_yolov5s6_model()

        yolov5_detection_model = Yolov5DetectionModel(
            model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH,
            confidence_threshold=0.5,
            device=None,
            category_remapping=None,
            load_at_init=True,
        )

        # read the test image
        image_path = "tests/data/small-vehicles1.jpeg"
        image = read_image(image_path)

        # run the detector
        yolov5_detection_model.perform_inference(image)

        # convert raw output into ObjectPrediction objects
        yolov5_detection_model.convert_original_predictions()
        object_prediction_list = yolov5_detection_model.object_prediction_list

        # verify count and the first/sixth detections
        self.assertEqual(len(object_prediction_list), 14)
        self.assertEqual(object_prediction_list[0].category.id, 2)
        self.assertEqual(object_prediction_list[0].category.name, "car")
        desired_bbox = [321, 322, 62, 40]
        predicted_bbox = object_prediction_list[0].bbox.to_coco_bbox()
        margin = 2
        # every coordinate must fall strictly inside +/- margin of the target
        for expected, actual in zip(desired_bbox, predicted_bbox):
            assert expected - margin < actual < expected + margin
        self.assertEqual(object_prediction_list[5].category.id, 2)
        self.assertEqual(object_prediction_list[5].category.name, "car")
        self.assertEqual(
            object_prediction_list[5].bbox.to_coco_bbox(),
            [617, 195, 24, 23],
        )
예제 #11
0
    def test_convert_original_predictions_without_mask_output(self):
        """Converted retinanet (bbox-only) predictions should match expected boxes."""
        from sahi.model import MmdetDetectionModel

        # download weights and build the model, loading it immediately
        download_mmdet_retinanet_model()

        mmdet_detection_model = MmdetDetectionModel(
            model_path=MmdetTestConstants.MMDET_RETINANET_MODEL_PATH,
            config_path=MmdetTestConstants.MMDET_RETINANET_CONFIG_PATH,
            confidence_threshold=0.5,
            device=None,
            category_remapping=None,
            load_at_init=True,
        )

        # read the test image
        image_path = "tests/data/small-vehicles1.jpeg"
        image = read_image(image_path)

        # run the detector
        mmdet_detection_model.perform_inference(image)

        # convert raw output into ObjectPrediction objects
        mmdet_detection_model.convert_original_predictions()
        object_prediction_list = mmdet_detection_model.object_prediction_list

        # verify count and the first/sixth detections
        self.assertEqual(len(object_prediction_list), 100)
        self.assertEqual(object_prediction_list[0].category.id, 2)
        self.assertEqual(object_prediction_list[0].category.name, "car")
        self.assertEqual(
            object_prediction_list[0].bbox.to_coco_bbox(),
            [448, 309, 47, 32],
        )
        self.assertEqual(object_prediction_list[5].category.id, 2)
        self.assertEqual(object_prediction_list[5].category.name, "car")
        self.assertEqual(
            object_prediction_list[5].bbox.to_coco_bbox(),
            [523, 225, 22, 17],
        )
예제 #12
0
    def test_perform_inference(self):
        """Raw yolov5 inference should find the expected car box."""
        from sahi.model import Yolov5DetectionModel

        # download weights and build the model, loading it immediately
        download_yolov5s6_model()

        yolov5_detection_model = Yolov5DetectionModel(
            model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH,
            confidence_threshold=0.5,
            device=None,
            category_remapping=None,
            load_at_init=True,
        )

        # read the test image
        image_path = "tests/data/small-vehicles1.jpeg"
        image = read_image(image_path)

        # run the detector and grab the raw predictions
        yolov5_detection_model.perform_inference(image)
        original_predictions = yolov5_detection_model.original_predictions

        boxes = original_predictions.xyxy

        # find box of first car detection with conf greater than 0.5
        for box in boxes[0]:
            if box[5].item() == 2 and box[4].item() > 0.5:
                break

        # every coordinate must fall strictly inside +/- margin of the target
        desired_bbox = [321, 322, 383, 362]
        predicted_bbox = [int(coord) for coord in box[:4].tolist()]
        margin = 2
        for expected, actual in zip(desired_bbox, predicted_bbox):
            assert expected - margin < actual < expected + margin
        self.assertEqual(len(original_predictions.names), 80)
예제 #13
0
    def test_perform_inference(self):
        """Raw cascade mask rcnn inference should expose (boxes, masks) output."""
        from sahi.model import MmdetDetectionModel

        # init model
        download_mmdet_cascade_mask_rcnn_model()

        mmdet_detection_model = MmdetDetectionModel(
            model_path=mmdet_cascade_mask_rcnn_model_path,
            config_path=mmdet_cascade_mask_rcnn_config_path,
            prediction_score_threshold=0.5,
            device=None,
            category_remapping=None,
        )
        mmdet_detection_model.load_model()

        # prepare image
        image_path = "tests/data/small-vehicles1.jpeg"
        image = read_image(image_path)

        # perform inference
        mmdet_detection_model.perform_inference(image)
        original_predictions = mmdet_detection_model.original_predictions

        # with mask output, original_predictions is (per-category boxes, masks)
        boxes = original_predictions[0]
        masks = original_predictions[1]

        # find box of first person detection with conf greater than 0.5
        # (debug print removed to keep test output clean)
        for box in boxes[0]:
            if len(box) == 5 and box[4] > 0.5:
                break

        # compare
        self.assertEqual(box[:4].astype("int").tolist(), [336, 123, 346, 139])
        self.assertEqual(len(boxes), 80)
        self.assertEqual(len(masks), 80)
예제 #14
0
    def test_perform_inference_with_mask_output(self):
        """Raw cascade mask rcnn inference should expose (boxes, masks) output."""
        from sahi.model import MmdetDetectionModel

        # init model
        download_mmdet_cascade_mask_rcnn_model()

        mmdet_detection_model = MmdetDetectionModel(
            model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH,
            config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH,
            confidence_threshold=0.5,
            device=None,
            category_remapping=None,
            load_at_init=True,
        )

        # prepare image
        image_path = "tests/data/small-vehicles1.jpeg"
        image = read_image(image_path)

        # perform inference
        mmdet_detection_model.perform_inference(image)
        original_predictions = mmdet_detection_model.original_predictions

        # with mask output, original_predictions is (per-category boxes, masks)
        boxes = original_predictions[0]
        masks = original_predictions[1]

        # find box of first person detection with conf greater than 0.5
        # (debug print removed to keep test output clean)
        for box in boxes[0]:
            if len(box) == 5 and box[4] > 0.5:
                break

        # compare
        self.assertEqual(box[:4].astype("int").tolist(), [336, 123, 346, 139])
        self.assertEqual(len(boxes), 80)
        self.assertEqual(len(masks), 80)
예제 #15
0
    def test_slice_image(self):
        """slice_image must produce identical slices for path, ndarray and PIL inputs."""
        # read coco file
        coco_path = "tests/data/coco_utils/terrain1_coco.json"
        coco = Coco.from_coco_dict_or_path(coco_path)

        output_file_name = None
        output_dir = None
        image_path = "tests/data/coco_utils/" + coco.images[0].file_name

        def slice_and_check(image):
            # slice with fixed parameters, then verify slice/annotation invariants
            slice_image_result = slice_image(
                image=image,
                coco_annotation_list=coco.images[0].annotations,
                output_file_name=output_file_name,
                output_dir=output_dir,
                slice_height=512,
                slice_width=512,
                overlap_height_ratio=0.1,
                overlap_width_ratio=0.4,
                min_area_ratio=0.1,
                out_ext=".png",
                verbose=False,
            )
            self.assertEqual(len(slice_image_result.images), 18)
            self.assertEqual(len(slice_image_result.coco_images), 18)
            self.assertEqual(slice_image_result.coco_images[0].annotations, [])
            self.assertEqual(
                slice_image_result.coco_images[15].annotations[1].area, 7296)
            self.assertEqual(
                slice_image_result.coco_images[15].annotations[1].bbox,
                [17, 186, 48, 152],
            )

        # image given as a file path
        slice_and_check(image_path)
        # image given as a numpy array
        slice_and_check(read_image(image_path))
        # image given as a PIL image
        slice_and_check(Image.open(image_path))
예제 #16
0
def get_prediction(
    image,
    detection_model,
    shift_amount: list = None,
    full_image_size=None,
    merger=None,
    matcher=None,
):
    """
    Perform prediction for the given image using the given detection_model.

    Arguments:
        image: str or np.ndarray
            Location of image or numpy image matrix to slice
        detection_model: model.DetectionModel
        shift_amount: List
            To shift the box and mask predictions from sliced image to full
            sized image, should be in the form of [shift_x, shift_y].
            Defaults to [0, 0] when None.
        full_image_size: List
            Size of the full image, should be in the form of [height, width]
        merger: postprocess.PredictionMerger
        matcher: postprocess.PredictionMatcher

    Returns:
        A dict with fields:
            object_prediction_list: a list of ObjectPrediction
            durations_in_seconds: a dict containing elapsed times for profiling
    """
    # None sentinel avoids a shared mutable default argument
    if shift_amount is None:
        shift_amount = [0, 0]

    durations_in_seconds = dict()

    # read image if image is str
    if isinstance(image, str):
        image = read_image(image)

    # get prediction
    time_start = time.time()
    detection_model.perform_inference(image)
    durations_in_seconds["prediction"] = time.time() - time_start

    # process prediction
    time_start = time.time()
    # works only with 1 batch
    detection_model.convert_original_predictions(
        shift_amount=shift_amount,
        full_image_size=full_image_size,
    )
    object_prediction_list = detection_model.object_prediction_list
    # drop predictions at or below the model's score threshold
    filtered_object_prediction_list = [
        object_prediction for object_prediction in object_prediction_list
        if object_prediction.score.score >
        detection_model.prediction_score_threshold
    ]
    # merge matching predictions
    if merger is not None:
        filtered_object_prediction_list = merger.merge_batch(
            matcher,
            filtered_object_prediction_list,
            merge_type="merge",
        )

    durations_in_seconds["postprocess"] = time.time() - time_start

    return {
        "object_prediction_list": filtered_object_prediction_list,
        "durations_in_seconds": durations_in_seconds,
    }
예제 #17
0
def predict_folder(
    model_name="MmdetDetectionModel",
    model_parameters=None,
    image_dir=None,
    visual_output_dir=None,
    pickle_dir=None,
    crop_dir=None,
    apply_sliced_prediction: bool = True,
    slice_height: int = 256,
    slice_width: int = 256,
    overlap_height_ratio: float = 0.1,
    overlap_width_ratio: float = 0.2,
    visual_bbox_thickness: int = 1,
    visual_text_size: float = 1,
    visual_text_thickness: int = 1,
):
    """
    Performs prediction for all present images in given folder.

    Args:
        model_name: str
            Name of the implemented DetectionModel in model.py file.
        model_parameters: a dict with fields:
            model_path: str
                Path for the instance segmentation model weight
            config_path: str
                Path for the mmdetection instance segmentation model config file
            prediction_score_threshold: float
                All predictions with score < prediction_score_threshold will be discarded.
            device: str
                Torch device, "cpu" or "cuda"
            category_remapping: dict: str to int
                Remap category ids after performing inference
        image_dir: str
            Directory that contain images to be predicted.
        visual_output_dir: str
            Directory that prediction visuals are going to be exported. Set to None for no visuals.
        pickle_dir: str
            Directory that object prediction list pickles are going to be exported. Set to None for no pickles.
        crop_dir: str
            Directory that detected bounding boxes will be cropped&exported. Set to None for no crops.
        apply_sliced_prediction: bool
            Set to True if you want sliced prediction, set to False for full prediction.
        slice_height: int
            Height of each slice.  Defaults to ``256``.
        slice_width: int
            Width of each slice.  Defaults to ``256``.
        overlap_height_ratio: float
            Fractional overlap in height of each window (e.g. an overlap of 0.2 for a window
            of size 256 yields an overlap of 51 pixels).
            Defaults to ``0.1``.
        overlap_width_ratio: float
            Fractional overlap in width of each window (e.g. an overlap of 0.2 for a window
            of size 256 yields an overlap of 51 pixels).
            Defaults to ``0.2``.
        visual_bbox_thickness: int
        visual_text_size: float
        visual_text_thickness: int
    """
    # collect every jpg/jpeg/png under image_dir
    image_path_list = list_files(directory=image_dir,
                                 contains=[".jpg", ".jpeg", ".png"])
    # init model instance; model_name is resolved dynamically to a DetectionModel class
    DetectionModel = import_class(model_name)
    detection_model = DetectionModel(
        model_path=model_parameters["model_path"],
        config_path=model_parameters["config_path"],
        prediction_score_threshold=model_parameters[
            "prediction_score_threshold"],
        device=model_parameters["device"],
        category_mapping=model_parameters["category_mapping"],
        category_remapping=model_parameters["category_remapping"],
    )
    detection_model.load_model()

    # iterate over source images
    for image_path in tqdm(image_path_list):
        # get filename
        (
            filename_with_extension,
            filename_without_extension,
        ) = get_base_filename(path=image_path)
        # load image
        image = read_image(image_path)

        # perform prediction
        if apply_sliced_prediction:
            # get sliced prediction
            prediction_result = get_sliced_prediction(
                image=image,
                detection_model=detection_model,
                slice_height=slice_height,
                slice_width=slice_width,
                overlap_height_ratio=overlap_height_ratio,
                overlap_width_ratio=overlap_width_ratio,
            )
            object_prediction_list = prediction_result[
                "object_prediction_list"]
        else:
            # get full sized prediction
            prediction_result = get_prediction(
                image=image,
                detection_model=detection_model,
                shift_amount=[0, 0],
                full_image_size=None,
                merger=None,
                matcher=None,
            )
            object_prediction_list = prediction_result[
                "object_prediction_list"]

        # export prediction boxes
        if crop_dir:
            crop_object_predictions(
                image=image,
                object_prediction_list=object_prediction_list,
                output_dir=crop_dir,
                file_name=filename_without_extension,
            )
        # export prediction list as pickle
        if pickle_dir:
            save_path = os.path.join(
                pickle_dir,
                filename_without_extension + ".pickle",
            )
            save_pickle(data=object_prediction_list, save_path=save_path)
        # export visualization
        if visual_output_dir:
            visualize_object_predictions(
                image,
                object_prediction_list=object_prediction_list,
                rect_th=visual_bbox_thickness,
                text_size=visual_text_size,
                text_th=visual_text_thickness,
                output_dir=visual_output_dir,
                file_name=filename_without_extension,
            )
예제 #18
0
import matplotlib.pyplot as plt
import matplotlib
# import required functions, classes
from sahi.model import Yolov5DetectionModel
from sahi.utils.cv import read_image, visualize_object_predictions, ipython_display
from sahi.predict import get_prediction, get_sliced_prediction, predict

# build a yolov5 detection model from locally fine-tuned weights
detection_model = Yolov5DetectionModel(
    model_path=
    '/data2/mritime/yolov5/yolov5/runs/train/6epochs_1024_smd_hous1_charls0/weights/best.pt',
    prediction_score_threshold=0.3,  # discard detections scoring below 0.3
    device="cuda",  # or 'cpu'
)

# single video frame to run prediction on
image_dir = "/data2/mritime/youtube_videos/houston_1080/fr_136290.png"
image = read_image(image_dir)

# result = get_prediction(image, detection_model)


def vis_pred():
    result = get_sliced_prediction(image,
                                   detection_model,
                                   slice_height=1024,
                                   slice_width=1024,
                                   overlap_height_ratio=0.2,
                                   overlap_width_ratio=0.2)
    visualization_result = visualize_object_predictions(
        image,
        object_prediction_list=result["object_prediction_list"],
        output_dir=None,