Пример #1
0
    def test_perform_inference_without_mask_output(self):
        """Run RetinaNet (box-only) inference and check the raw per-class output.

        Without a mask head, ``original_predictions`` is indexed directly by
        class (80 entries), so ``boxes[2]`` holds the car detections.
        """
        from sahi.model import MmdetDetectionModel

        # init model
        download_mmdet_retinanet_model()

        mmdet_detection_model = MmdetDetectionModel(
            model_path=MmdetTestConstants.MMDET_RETINANET_MODEL_PATH,
            config_path=MmdetTestConstants.MMDET_RETINANET_CONFIG_PATH,
            confidence_threshold=0.5,
            device=None,
            category_remapping=None,
            load_at_init=True,
        )

        # prepare image
        image_path = "tests/data/small-vehicles1.jpeg"
        image = read_image(image_path)

        # perform inference
        mmdet_detection_model.perform_inference(image)
        boxes = mmdet_detection_model.original_predictions

        # first car (class index 2) row of length 5 whose 5th value (score)
        # exceeds 0.5; the first four values are compared as box coordinates
        box = next(
            (row for row in boxes[2] if len(row) == 5 and row[4] > 0.5),
            None,
        )
        # fail explicitly rather than asserting on whatever row a loop ended on
        self.assertIsNotNone(box, "no car detection with score > 0.5 found")

        # compare
        self.assertEqual(box[:4].astype("int").tolist(), [448, 309, 495, 341])
        self.assertEqual(len(boxes), 80)
Пример #2
0
    def test_get_sliced_prediction(self):
        """Run sliced inference with Cascade Mask R-CNN and check per-class counts.

        The image is cut into 512x512 slices (0.1 height / 0.2 width overlap
        ratio) and the merged predictions are counted per category name.
        """
        from sahi.model import MmdetDetectionModel
        from sahi.predict import get_sliced_prediction

        from tests.test_utils import (
            download_mmdet_cascade_mask_rcnn_model,
            mmdet_cascade_mask_rcnn_config_path,
            mmdet_cascade_mask_rcnn_model_path,
        )

        # init model
        download_mmdet_cascade_mask_rcnn_model()

        mmdet_detection_model = MmdetDetectionModel(
            model_path=mmdet_cascade_mask_rcnn_model_path,
            config_path=mmdet_cascade_mask_rcnn_config_path,
            prediction_score_threshold=0.3,
            device=None,
            category_remapping=None,
        )
        mmdet_detection_model.load_model()

        # prepare image
        image_path = "tests/data/small-vehicles1.jpeg"

        slice_height = 512
        slice_width = 512
        overlap_height_ratio = 0.1
        overlap_width_ratio = 0.2

        # get sliced prediction
        prediction_result = get_sliced_prediction(
            image=image_path,
            detection_model=mmdet_detection_model,
            slice_height=slice_height,
            slice_width=slice_width,
            overlap_height_ratio=overlap_height_ratio,
            overlap_width_ratio=overlap_width_ratio,
        )
        object_prediction_list = prediction_result["object_prediction_list"]

        def count_category(category_name):
            """Return how many predictions carry the given category name."""
            return sum(
                1
                for object_prediction in object_prediction_list
                if object_prediction.category.name == category_name
            )

        # compare: expected 0 person + 2 truck + 22 car == 24 predictions
        self.assertEqual(len(object_prediction_list), 24)
        self.assertEqual(count_category("person"), 0)
        # each truck detection counts exactly once
        self.assertEqual(count_category("truck"), 2)
        self.assertEqual(count_category("car"), 22)
Пример #3
0
    def test_get_prediction(self):
        """Run full-image inference with Cascade Mask R-CNN and check per-class counts."""
        from sahi.model import MmdetDetectionModel
        from sahi.predict import get_prediction

        from tests.test_utils import (
            download_mmdet_cascade_mask_rcnn_model,
            mmdet_cascade_mask_rcnn_config_path,
            mmdet_cascade_mask_rcnn_model_path,
        )

        # init model
        download_mmdet_cascade_mask_rcnn_model()

        mmdet_detection_model = MmdetDetectionModel(
            model_path=mmdet_cascade_mask_rcnn_model_path,
            config_path=mmdet_cascade_mask_rcnn_config_path,
            prediction_score_threshold=0.3,
            device=None,
            category_remapping=None,
        )
        mmdet_detection_model.load_model()

        # prepare image
        image_path = "tests/data/small-vehicles1.jpeg"
        image = read_image(image_path)

        # get full sized prediction (no shift, no merging/matching)
        prediction_result = get_prediction(
            image=image,
            detection_model=mmdet_detection_model,
            shift_amount=[0, 0],
            full_image_size=None,
            merger=None,
            matcher=None,
        )
        object_prediction_list = prediction_result["object_prediction_list"]

        def count_category(category_name):
            """Return how many predictions carry the given category name."""
            return sum(
                1
                for object_prediction in object_prediction_list
                if object_prediction.category.name == category_name
            )

        # compare: expected 0 person + 3 truck + 20 car == 23 predictions
        self.assertEqual(len(object_prediction_list), 23)
        self.assertEqual(count_category("person"), 0)
        self.assertEqual(count_category("truck"), 3)
        self.assertEqual(count_category("car"), 20)
Пример #4
0
    def test_load_model(self):
        """Check that ``load_model`` populates the wrapped mmdet model."""
        from sahi.model import MmdetDetectionModel

        download_mmdet_cascade_mask_rcnn_model()

        mmdet_detection_model = MmdetDetectionModel(
            model_path=mmdet_cascade_mask_rcnn_model_path,
            config_path=mmdet_cascade_mask_rcnn_config_path,
            prediction_score_threshold=0.3,
            device=None,
            category_remapping=None,
        )
        mmdet_detection_model.load_model()
        # identity check against None (not equality) with a descriptive message
        self.assertIsNotNone(mmdet_detection_model.model)
Пример #5
0
    def test_get_prediction_mmdet(self):
        """Run full-image Cascade Mask R-CNN inference and verify per-category counts."""
        from sahi.model import MmdetDetectionModel
        from sahi.predict import get_prediction
        from sahi.utils.mmdet import MmdetTestConstants, download_mmdet_cascade_mask_rcnn_model

        # make sure the checkpoint is available, then build the detector
        download_mmdet_cascade_mask_rcnn_model()

        detector = MmdetDetectionModel(
            model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH,
            config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH,
            confidence_threshold=0.3,
            device=None,
            category_remapping=None,
        )
        detector.load_model()

        image = read_image("tests/data/small-vehicles1.jpeg")

        # single pass over the whole image with zero offset
        result = get_prediction(
            image=image,
            detection_model=detector,
            shift_amount=[0, 0],
            full_shape=None,
        )
        predictions = result.object_prediction_list

        self.assertEqual(len(predictions), 19)

        # checked in this order: person, truck, car
        expected_counts = {"person": 0, "truck": 0, "car": 19}
        for category_name, expected_count in expected_counts.items():
            actual_count = sum(
                1 for prediction in predictions if prediction.category.name == category_name
            )
            self.assertEqual(actual_count, expected_count)
Пример #6
0
    def test_create_original_predictions_from_object_prediction_list(self):
        """Round-trip raw predictions through ObjectPrediction objects and back.

        Converts the model's raw output to an ObjectPrediction list, rebuilds
        raw predictions from it with
        ``_create_original_predictions_from_object_prediction_list`` and checks
        that both structures agree in length, container type, dtype and shape.
        """
        from sahi.model import MmdetDetectionModel

        # init model
        download_mmdet_cascade_mask_rcnn_model()

        mmdet_detection_model = MmdetDetectionModel(
            model_path=mmdet_cascade_mask_rcnn_model_path,
            config_path=mmdet_cascade_mask_rcnn_config_path,
            prediction_score_threshold=0.5,
            device=None,
            category_remapping=None,
        )
        mmdet_detection_model.load_model()

        # prepare image
        image_path = "tests/data/small-vehicles1.jpeg"
        image = read_image(image_path)

        # perform inference
        mmdet_detection_model.perform_inference(image)
        original_predictions_1 = mmdet_detection_model.original_predictions

        # convert predictions to ObjectPrediction list
        mmdet_detection_model.convert_original_predictions()
        object_prediction_list = mmdet_detection_model.object_prediction_list

        # rebuild raw predictions from the ObjectPrediction list
        original_predictions_2 = mmdet_detection_model._create_original_predictions_from_object_prediction_list(
            object_prediction_list)

        # compare; presumably index 0 holds per-class boxes and index 1
        # per-class masks (cf. the float32 / bool dtype comments below)
        self.assertEqual(len(original_predictions_1),
                         len(original_predictions_2))  # 2
        self.assertEqual(len(original_predictions_1[0]),
                         len(original_predictions_2[0]))  # 80
        self.assertEqual(len(original_predictions_1[0][2]),
                         len(original_predictions_2[0][2]))  # 25
        self.assertEqual(type(original_predictions_1[0]),
                         type(original_predictions_2[0]))  # list
        self.assertEqual(original_predictions_1[0][2].dtype,
                         original_predictions_2[0][2].dtype)  # float32
        self.assertEqual(original_predictions_1[0][0][0].dtype,
                         original_predictions_2[0][0][0].dtype)  # float32
        self.assertEqual(original_predictions_1[1][0][0].dtype,
                         original_predictions_2[1][0][0].dtype)  # bool
        self.assertEqual(len(original_predictions_1[0][0][0]),
                         len(original_predictions_2[0][0][0]))  # 5
        # an empty class must also survive the round trip
        self.assertEqual(len(original_predictions_1[0][1]),
                         len(original_predictions_2[0][1]))  # 0
        self.assertEqual(original_predictions_1[0][1].shape,
                         original_predictions_2[0][1].shape)  # (0, 5)
Пример #7
0
    def test_convert_original_predictions(self):
        """Convert Cascade Mask R-CNN raw output to ObjectPrediction objects and spot-check them."""
        from sahi.model import MmdetDetectionModel

        # init model
        download_mmdet_cascade_mask_rcnn_model()

        mmdet_detection_model = MmdetDetectionModel(
            model_path=mmdet_cascade_mask_rcnn_model_path,
            config_path=mmdet_cascade_mask_rcnn_config_path,
            prediction_score_threshold=0.5,
            device=None,
            category_remapping=None,
        )
        mmdet_detection_model.load_model()

        # prepare image
        image_path = "tests/data/small-vehicles1.jpeg"
        image = read_image(image_path)

        # perform inference
        mmdet_detection_model.perform_inference(image)

        # convert predictions to ObjectPrediction list
        mmdet_detection_model.convert_original_predictions()
        predictions = mmdet_detection_model.object_prediction_list

        # compare
        self.assertEqual(len(predictions), 53)

        # (category index, id, name, bbox index, COCO bbox [x, y, w, h])
        # NOTE(review): the last row checks the category of predictions[5] but
        # the bbox of predictions[2]; kept as-is — confirm the intended index.
        expectations = (
            (0, 0, "person", 0, [337, 124, 8, 14]),
            (1, 2, "car", 1, [657, 204, 13, 10]),
            (5, 2, "car", 2, [760, 232, 20, 15]),
        )
        for cat_idx, cat_id, cat_name, bbox_idx, coco_bbox in expectations:
            self.assertEqual(predictions[cat_idx].category.id, cat_id)
            self.assertEqual(predictions[cat_idx].category.name, cat_name)
            self.assertEqual(predictions[bbox_idx].bbox.to_coco_bbox(), coco_bbox)
Пример #8
0
    def test_perform_inference_with_mask_output(self):
        """Run Cascade Mask R-CNN inference and check the raw box/mask output.

        With a mask head, ``original_predictions`` is a pair whose first item
        holds per-class boxes and whose second holds per-class masks (80 each).
        """
        from sahi.model import MmdetDetectionModel

        # init model
        download_mmdet_cascade_mask_rcnn_model()

        mmdet_detection_model = MmdetDetectionModel(
            model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH,
            config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH,
            confidence_threshold=0.5,
            device=None,
            category_remapping=None,
            load_at_init=True,
        )

        # prepare image
        image_path = "tests/data/small-vehicles1.jpeg"
        image = read_image(image_path)

        # perform inference
        mmdet_detection_model.perform_inference(image)
        original_predictions = mmdet_detection_model.original_predictions

        boxes = original_predictions[0]
        masks = original_predictions[1]

        # first person (class index 0) row of length 5 whose 5th value
        # (score) exceeds 0.5
        box = next(
            (row for row in boxes[0] if len(row) == 5 and row[4] > 0.5),
            None,
        )
        # fail explicitly rather than asserting on whatever row a loop ended on
        self.assertIsNotNone(box, "no person detection with score > 0.5 found")

        # compare
        self.assertEqual(box[:4].astype("int").tolist(), [336, 123, 346, 139])
        self.assertEqual(len(boxes), 80)
        self.assertEqual(len(masks), 80)
Пример #9
0
    def test_create_original_predictions_from_object_prediction_list_with_mask_output(
        self,
    ):
        """Round-trip mask-model raw predictions through ObjectPrediction objects and back.

        Converts the raw output to an ObjectPrediction list, rebuilds raw
        predictions with
        ``_create_original_predictions_from_object_prediction_list`` and checks
        that both structures agree in length, container type, dtype and shape.
        """
        from sahi.model import MmdetDetectionModel

        # init model
        download_mmdet_cascade_mask_rcnn_model()

        mmdet_detection_model = MmdetDetectionModel(
            model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH,
            config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH,
            confidence_threshold=0.5,
            device=None,
            category_remapping=None,
            load_at_init=True,
        )

        # prepare image
        image_path = "tests/data/small-vehicles1.jpeg"
        image = read_image(image_path)

        # perform inference
        mmdet_detection_model.perform_inference(image)
        original_predictions_1 = mmdet_detection_model.original_predictions

        # convert predictions to ObjectPrediction list
        mmdet_detection_model.convert_original_predictions()
        object_prediction_list = mmdet_detection_model.object_prediction_list

        # rebuild raw predictions from the ObjectPrediction list
        original_predictions_2 = mmdet_detection_model._create_original_predictions_from_object_prediction_list(
            object_prediction_list)

        # compare; presumably index 0 holds per-class boxes and index 1
        # per-class masks (cf. the float32 / bool dtype comments below)
        self.assertEqual(len(original_predictions_1),
                         len(original_predictions_2))  # 2
        self.assertEqual(len(original_predictions_1[0]),
                         len(original_predictions_2[0]))  # 80
        self.assertEqual(len(original_predictions_1[0][2]),
                         len(original_predictions_2[0][2]))  # 25
        self.assertEqual(type(original_predictions_1[0]),
                         type(original_predictions_2[0]))  # list
        self.assertEqual(original_predictions_1[0][2].dtype,
                         original_predictions_2[0][2].dtype)  # float32
        self.assertEqual(original_predictions_1[0][0][0].dtype,
                         original_predictions_2[0][0][0].dtype)  # float32
        self.assertEqual(original_predictions_1[1][0][0].dtype,
                         original_predictions_2[1][0][0].dtype)  # bool
        self.assertEqual(len(original_predictions_1[0][0][0]),
                         len(original_predictions_2[0][0][0]))  # 5
        # an empty class must also survive the round trip
        self.assertEqual(len(original_predictions_1[0][1]),
                         len(original_predictions_2[0][1]))  # 0
        self.assertEqual(original_predictions_1[0][1].shape,
                         original_predictions_2[0][1].shape)  # (0, 5)
Пример #10
0
    def test_load_model(self):
        """Check that ``load_at_init=True`` populates the wrapped mmdet model."""
        from sahi.model import MmdetDetectionModel

        download_mmdet_cascade_mask_rcnn_model()

        mmdet_detection_model = MmdetDetectionModel(
            model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH,
            config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH,
            confidence_threshold=0.3,
            device=None,
            category_remapping=None,
            load_at_init=True,
        )

        # identity check against None (not equality) with a descriptive message
        self.assertIsNotNone(mmdet_detection_model.model)
Пример #11
0
    def test_convert_original_predictions_with_mask_output(self):
        """Convert Cascade Mask R-CNN raw output to ObjectPrediction objects and spot-check them."""
        from sahi.model import MmdetDetectionModel

        # make sure the checkpoint is available, then build the detector
        download_mmdet_cascade_mask_rcnn_model()

        model = MmdetDetectionModel(
            model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH,
            config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH,
            confidence_threshold=0.5,
            device=None,
            category_remapping=None,
            load_at_init=True,
        )

        # run inference on the sample image, then convert the raw output
        model.perform_inference(read_image("tests/data/small-vehicles1.jpeg"))
        model.convert_original_predictions()
        predictions = model.object_prediction_list

        self.assertEqual(len(predictions), 53)

        # (category index, id, name, bbox index, COCO bbox [x, y, w, h])
        # NOTE(review): the last row checks the category of predictions[5] but
        # the bbox of predictions[2]; kept as-is — confirm the intended index.
        expectations = (
            (0, 0, "person", 0, [337, 124, 8, 14]),
            (1, 2, "car", 1, [657, 204, 13, 10]),
            (5, 2, "car", 2, [760, 232, 20, 15]),
        )
        for cat_idx, cat_id, cat_name, bbox_idx, coco_bbox in expectations:
            self.assertEqual(predictions[cat_idx].category.id, cat_id)
            self.assertEqual(predictions[cat_idx].category.name, cat_name)
            self.assertEqual(predictions[bbox_idx].bbox.to_coco_bbox(), coco_bbox)
Пример #12
0
    def test_convert_original_predictions_without_mask_output(self):
        """Convert RetinaNet (box-only) raw output to ObjectPrediction objects and spot-check them."""
        from sahi.model import MmdetDetectionModel

        # make sure the checkpoint is available, then build the detector
        download_mmdet_retinanet_model()

        model = MmdetDetectionModel(
            model_path=MmdetTestConstants.MMDET_RETINANET_MODEL_PATH,
            config_path=MmdetTestConstants.MMDET_RETINANET_CONFIG_PATH,
            confidence_threshold=0.5,
            device=None,
            category_remapping=None,
            load_at_init=True,
        )

        # run inference on the sample image, then convert the raw output
        model.perform_inference(read_image("tests/data/small-vehicles1.jpeg"))
        model.convert_original_predictions()
        predictions = model.object_prediction_list

        self.assertEqual(len(predictions), 100)

        # prediction index -> (category id, category name, COCO bbox [x, y, w, h])
        expected = {
            0: (2, "car", [448, 309, 47, 32]),
            5: (2, "car", [523, 225, 22, 17]),
        }
        for index, (category_id, category_name, coco_bbox) in expected.items():
            prediction = predictions[index]
            self.assertEqual(prediction.category.id, category_id)
            self.assertEqual(prediction.category.name, category_name)
            self.assertEqual(prediction.bbox.to_coco_bbox(), coco_bbox)
Пример #13
0
    def test_perform_inference(self):
        """Run Cascade Mask R-CNN inference and check the raw box/mask output.

        ``original_predictions`` is a pair whose first item holds per-class
        boxes and whose second holds per-class masks (80 each).
        """
        from sahi.model import MmdetDetectionModel

        # init model
        download_mmdet_cascade_mask_rcnn_model()

        mmdet_detection_model = MmdetDetectionModel(
            model_path=mmdet_cascade_mask_rcnn_model_path,
            config_path=mmdet_cascade_mask_rcnn_config_path,
            prediction_score_threshold=0.5,
            device=None,
            category_remapping=None,
        )
        mmdet_detection_model.load_model()

        # prepare image
        image_path = "tests/data/small-vehicles1.jpeg"
        image = read_image(image_path)

        # perform inference
        mmdet_detection_model.perform_inference(image)
        original_predictions = mmdet_detection_model.original_predictions

        boxes = original_predictions[0]
        masks = original_predictions[1]

        # first person (class index 0) row of length 5 whose 5th value
        # (score) exceeds 0.5
        box = next(
            (row for row in boxes[0] if len(row) == 5 and row[4] > 0.5),
            None,
        )
        # fail explicitly rather than asserting on whatever row a loop ended on
        self.assertIsNotNone(box, "no person detection with score > 0.5 found")

        # compare
        self.assertEqual(box[:4].astype("int").tolist(), [336, 123, 346, 139])
        self.assertEqual(len(boxes), 80)
        self.assertEqual(len(masks), 80)
Пример #14
0
    def test_get_sliced_prediction_mmdet(self):
        """Run sliced Cascade Mask R-CNN inference and check per-class counts.

        Uses 512x512 slices (0.1 height / 0.2 width overlap ratio), skips the
        standard full-image pass, and merges slice results with UNIONMERGE
        (IOS metric, threshold 0.5, class-agnostic).
        """
        from sahi.model import MmdetDetectionModel
        from sahi.predict import get_sliced_prediction
        from sahi.utils.mmdet import MmdetTestConstants, download_mmdet_cascade_mask_rcnn_model

        # init model
        download_mmdet_cascade_mask_rcnn_model()

        mmdet_detection_model = MmdetDetectionModel(
            model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH,
            config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH,
            confidence_threshold=0.3,
            device=None,
            category_remapping=None,
            load_at_init=False,
        )
        mmdet_detection_model.load_model()

        # prepare image
        image_path = "tests/data/small-vehicles1.jpeg"

        slice_height = 512
        slice_width = 512
        overlap_height_ratio = 0.1
        overlap_width_ratio = 0.2
        postprocess_type = "UNIONMERGE"
        match_metric = "IOS"
        match_threshold = 0.5
        class_agnostic = True

        # get sliced prediction
        prediction_result = get_sliced_prediction(
            image=image_path,
            detection_model=mmdet_detection_model,
            slice_height=slice_height,
            slice_width=slice_width,
            overlap_height_ratio=overlap_height_ratio,
            overlap_width_ratio=overlap_width_ratio,
            perform_standard_pred=False,
            postprocess_type=postprocess_type,
            postprocess_match_metric=match_metric,
            postprocess_match_threshold=match_threshold,
            postprocess_class_agnostic=class_agnostic,
        )
        object_prediction_list = prediction_result.object_prediction_list

        def count_category(category_name):
            """Return how many predictions carry the given category name."""
            return sum(
                1
                for object_prediction in object_prediction_list
                if object_prediction.category.name == category_name
            )

        # compare: expected 0 person + 2 truck + 22 car == 24 predictions
        self.assertEqual(len(object_prediction_list), 24)
        self.assertEqual(count_category("person"), 0)
        # each truck detection counts exactly once
        self.assertEqual(count_category("truck"), 2)
        self.assertEqual(count_category("car"), 22)