def test_perform_inference_without_mask_output(self):
    """RetinaNet (box-only) inference exposes raw boxes with 80 category slots."""
    from sahi.model import MmdetDetectionModel

    # init model
    download_mmdet_retinanet_model()
    mmdet_detection_model = MmdetDetectionModel(
        model_path=MmdetTestConstants.MMDET_RETINANET_MODEL_PATH,
        config_path=MmdetTestConstants.MMDET_RETINANET_CONFIG_PATH,
        confidence_threshold=0.5,
        device=None,
        category_remapping=None,
        load_at_init=True,
    )

    # prepare image
    image_path = "tests/data/small-vehicles1.jpeg"
    image = read_image(image_path)

    # perform inference
    mmdet_detection_model.perform_inference(image)
    original_predictions = mmdet_detection_model.original_predictions
    boxes = original_predictions

    # find box of first car detection with conf greater than 0.5
    # (leftover debug print removed; fail explicitly when nothing matches
    # instead of silently asserting against the last box)
    for box in boxes[2]:
        if len(box) == 5 and box[4] > 0.5:
            break
    else:
        self.fail("no car detection with confidence > 0.5 found")

    # compare
    self.assertEqual(box[:4].astype("int").tolist(), [448, 309, 495, 341])
    self.assertEqual(len(boxes), 80)
def test_get_prediction(self):
    """Full-image prediction returns the expected per-category counts."""
    from sahi.model import MmdetDetectionModel
    from sahi.predict import get_prediction
    from sahi.prediction import PredictionInput
    from tests.test_utils import (
        download_mmdet_cascade_mask_rcnn_model,
        mmdet_cascade_mask_rcnn_config_path,
        mmdet_cascade_mask_rcnn_model_path,
    )

    # init model
    download_mmdet_cascade_mask_rcnn_model()
    mmdet_detection_model = MmdetDetectionModel(
        model_path=mmdet_cascade_mask_rcnn_model_path,
        config_path=mmdet_cascade_mask_rcnn_config_path,
        prediction_score_threshold=0.3,
        device=None,
        category_remapping=None,
    )
    mmdet_detection_model.load_model()

    # prepare image
    image = read_image("tests/data/small-vehicles1.jpeg")

    # get full sized prediction
    prediction_result = get_prediction(
        image=image,
        detection_model=mmdet_detection_model,
        shift_amount=[0, 0],
        full_image_size=None,
        merger=None,
        matcher=None,
    )
    object_prediction_list = prediction_result["object_prediction_list"]

    # compare
    self.assertEqual(len(object_prediction_list), 23)

    def count_category(name):
        # number of predictions carrying the given COCO category name
        return sum(1 for pred in object_prediction_list
                   if pred.category.name == name)

    self.assertEqual(count_category("person"), 0)
    self.assertEqual(count_category("truck"), 3)
    self.assertEqual(count_category("car"), 20)
def test_convert_original_predictions(self):
    """Raw mmdet output converts to ObjectPredictions with mapped categories."""
    from sahi.model import MmdetDetectionModel
    from sahi.prediction import PredictionInput

    # init model
    download_mmdet_cascade_mask_rcnn_model()
    mmdet_detection_model = MmdetDetectionModel(
        model_path=mmdet_cascade_mask_rcnn_model_path,
        config_path=mmdet_cascade_mask_rcnn_config_path,
        prediction_score_threshold=0.5,
        device=None,
        category_remapping=None,
    )
    mmdet_detection_model.load_model()

    # prepare image and run inference
    image = read_image("tests/data/small-vehicles1.jpeg")
    mmdet_detection_model.perform_inference(image)

    # convert predictions to ObjectPrediction list
    mmdet_detection_model.convert_original_predictions()
    predictions = mmdet_detection_model.object_prediction_list

    # compare
    self.assertEqual(len(predictions), 53)

    self.assertEqual(predictions[0].category.id, 0)
    self.assertEqual(predictions[0].category.name, "person")
    self.assertEqual(predictions[0].bbox.to_coco_bbox(), [337, 124, 8, 14])

    self.assertEqual(predictions[1].category.id, 2)
    self.assertEqual(predictions[1].category.name, "car")
    self.assertEqual(predictions[1].bbox.to_coco_bbox(), [657, 204, 13, 10])

    self.assertEqual(predictions[5].category.id, 2)
    self.assertEqual(predictions[5].category.name, "car")
    # NOTE(review): bbox check uses index 2 while the category checks above
    # use index 5 — looks like a copy-paste slip, confirm which was intended
    self.assertEqual(predictions[2].bbox.to_coco_bbox(), [760, 232, 20, 15])
def test_create_original_predictions_from_object_prediction_list_with_mask_output(
    self,
):
    """Round-trip: raw mmdet output -> ObjectPrediction list -> raw format."""
    from sahi.model import MmdetDetectionModel

    # init model
    download_mmdet_cascade_mask_rcnn_model()
    mmdet_detection_model = MmdetDetectionModel(
        model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH,
        config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH,
        confidence_threshold=0.5,
        device=None,
        category_remapping=None,
        load_at_init=True,
    )

    # prepare image
    image_path = "tests/data/small-vehicles1.jpeg"
    image = read_image(image_path)

    # perform inference
    mmdet_detection_model.perform_inference(image)
    original_predictions_1 = mmdet_detection_model.original_predictions

    # convert predictions to ObjectPrediction list
    mmdet_detection_model.convert_original_predictions()
    object_prediction_list = mmdet_detection_model.object_prediction_list
    original_predictions_2 = mmdet_detection_model._create_original_predictions_from_object_prediction_list(
        object_prediction_list)

    # compare
    self.assertEqual(len(original_predictions_1), len(original_predictions_2))  # 2
    self.assertEqual(len(original_predictions_1[0]), len(original_predictions_2[0]))  # 80
    self.assertEqual(len(original_predictions_1[0][2]), len(original_predictions_2[0][2]))  # 25
    self.assertEqual(type(original_predictions_1[0]), type(original_predictions_2[0]))  # list
    self.assertEqual(original_predictions_1[0][2].dtype, original_predictions_2[0][2].dtype)  # float32
    self.assertEqual(original_predictions_1[0][0][0].dtype, original_predictions_2[0][0][0].dtype)  # float32
    self.assertEqual(original_predictions_1[1][0][0].dtype, original_predictions_2[1][0][0].dtype)  # bool
    self.assertEqual(len(original_predictions_1[0][0][0]), len(original_predictions_2[0][0][0]))  # 5
    # BUGFIX: the two checks below previously compared original_predictions_1
    # against itself (always true); compare against original_predictions_2.
    self.assertEqual(len(original_predictions_1[0][1]), len(original_predictions_2[0][1]))  # 0
    self.assertEqual(original_predictions_1[0][1].shape, original_predictions_2[0][1].shape)  # (0, 5)
def test_create_original_predictions_from_object_prediction_list(self):
    """Round-trip: raw output -> ObjectPrediction list -> raw format."""
    from sahi.model import MmdetDetectionModel
    from sahi.prediction import PredictionInput

    # init model
    download_mmdet_cascade_mask_rcnn_model()
    mmdet_detection_model = MmdetDetectionModel(
        model_path=mmdet_cascade_mask_rcnn_model_path,
        config_path=mmdet_cascade_mask_rcnn_config_path,
        prediction_score_threshold=0.5,
        device=None,
        category_remapping=None,
    )
    mmdet_detection_model.load_model()

    # prepare image
    image_path = "tests/data/small-vehicles1.jpeg"
    image = read_image(image_path)

    # perform inference
    mmdet_detection_model.perform_inference(image)
    original_predictions_1 = mmdet_detection_model.original_predictions

    # convert predictions to ObjectPrediction list
    mmdet_detection_model.convert_original_predictions()
    object_prediction_list = mmdet_detection_model.object_prediction_list
    original_predictions_2 = mmdet_detection_model._create_original_predictions_from_object_prediction_list(
        object_prediction_list)

    # compare
    self.assertEqual(len(original_predictions_1), len(original_predictions_2))  # 2
    self.assertEqual(len(original_predictions_1[0]), len(original_predictions_2[0]))  # 80
    self.assertEqual(len(original_predictions_1[0][2]), len(original_predictions_2[0][2]))  # 25
    self.assertEqual(type(original_predictions_1[0]), type(original_predictions_2[0]))  # list
    self.assertEqual(original_predictions_1[0][2].dtype, original_predictions_2[0][2].dtype)  # float32
    self.assertEqual(original_predictions_1[0][0][0].dtype, original_predictions_2[0][0][0].dtype)  # float32
    self.assertEqual(original_predictions_1[1][0][0].dtype, original_predictions_2[1][0][0].dtype)  # bool
    self.assertEqual(len(original_predictions_1[0][0][0]), len(original_predictions_2[0][0][0]))  # 5
    # BUGFIX: the two checks below previously compared original_predictions_1
    # against itself (always true); compare against original_predictions_2.
    self.assertEqual(len(original_predictions_1[0][1]), len(original_predictions_2[0][1]))  # 0
    self.assertEqual(original_predictions_1[0][1].shape, original_predictions_2[0][1].shape)  # (0, 5)
def test_prediction_input(self):
    """PredictionInput creates one shift amount per provided image."""
    from sahi.prediction import PredictionInput

    # prepare image
    image = read_image("tests/data/small-vehicles1.jpeg")

    # init prediction input
    prediction_input = PredictionInput(image_list=[image])

    # compare
    self.assertEqual(
        len(prediction_input.shift_amount_list),
        len(prediction_input.image_list),
    )
def test_get_prediction_mmdet(self):
    """Full-image mmdet prediction returns the expected per-category counts."""
    from sahi.model import MmdetDetectionModel
    from sahi.predict import get_prediction
    from sahi.utils.mmdet import MmdetTestConstants, download_mmdet_cascade_mask_rcnn_model

    # init model
    download_mmdet_cascade_mask_rcnn_model()
    mmdet_detection_model = MmdetDetectionModel(
        model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH,
        config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH,
        confidence_threshold=0.3,
        device=None,
        category_remapping=None,
    )
    mmdet_detection_model.load_model()

    # prepare image
    image = read_image("tests/data/small-vehicles1.jpeg")

    # get full sized prediction
    prediction_result = get_prediction(
        image=image,
        detection_model=mmdet_detection_model,
        shift_amount=[0, 0],
        full_shape=None,
    )
    object_prediction_list = prediction_result.object_prediction_list

    # compare
    self.assertEqual(len(object_prediction_list), 19)

    def count_category(name):
        # number of predictions carrying the given COCO category name
        return sum(1 for pred in object_prediction_list
                   if pred.category.name == name)

    self.assertEqual(count_category("person"), 0)
    self.assertEqual(count_category("truck"), 0)
    self.assertEqual(count_category("car"), 19)
def test_convert_original_predictions_with_mask_output(self):
    """Mask model raw output converts to the expected ObjectPrediction list."""
    from sahi.model import MmdetDetectionModel

    # init model
    download_mmdet_cascade_mask_rcnn_model()
    mmdet_detection_model = MmdetDetectionModel(
        model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH,
        config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH,
        confidence_threshold=0.5,
        device=None,
        category_remapping=None,
        load_at_init=True,
    )

    # prepare image and run inference
    image = read_image("tests/data/small-vehicles1.jpeg")
    mmdet_detection_model.perform_inference(image)

    # convert predictions to ObjectPrediction list
    mmdet_detection_model.convert_original_predictions()
    predictions = mmdet_detection_model.object_prediction_list

    # compare
    self.assertEqual(len(predictions), 53)

    self.assertEqual(predictions[0].category.id, 0)
    self.assertEqual(predictions[0].category.name, "person")
    self.assertEqual(predictions[0].bbox.to_coco_bbox(), [337, 124, 8, 14])

    self.assertEqual(predictions[1].category.id, 2)
    self.assertEqual(predictions[1].category.name, "car")
    self.assertEqual(predictions[1].bbox.to_coco_bbox(), [657, 204, 13, 10])

    self.assertEqual(predictions[5].category.id, 2)
    self.assertEqual(predictions[5].category.name, "car")
    # NOTE(review): bbox check uses index 2 while the category checks above
    # use index 5 — looks like a copy-paste slip, confirm which was intended
    self.assertEqual(predictions[2].bbox.to_coco_bbox(), [760, 232, 20, 15])
def test_get_prediction_yolov5(self):
    """Full-image yolov5 prediction returns the expected per-category counts."""
    from sahi.model import Yolov5DetectionModel
    from sahi.predict import get_prediction
    from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5s6_model

    # init model
    download_yolov5s6_model()
    yolov5_detection_model = Yolov5DetectionModel(
        model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH,
        confidence_threshold=0.3,
        device=None,
        category_remapping=None,
        load_at_init=False,
    )
    yolov5_detection_model.load_model()

    # prepare image
    image = read_image("tests/data/small-vehicles1.jpeg")

    # get full sized prediction
    prediction_result = get_prediction(
        image=image,
        detection_model=yolov5_detection_model,
        shift_amount=[0, 0],
        full_shape=None,
        postprocess=None,
    )
    object_prediction_list = prediction_result.object_prediction_list

    # compare
    self.assertEqual(len(object_prediction_list), 12)

    def count_category(name):
        # number of predictions carrying the given COCO category name
        return sum(1 for pred in object_prediction_list
                   if pred.category.name == name)

    self.assertEqual(count_category("person"), 0)
    self.assertEqual(count_category("truck"), 0)
    self.assertEqual(count_category("car"), 12)
def test_convert_original_predictions(self):
    """Yolov5 raw detections convert to the expected ObjectPrediction list."""
    from sahi.model import Yolov5DetectionModel

    # init model
    download_yolov5s6_model()
    yolov5_detection_model = Yolov5DetectionModel(
        model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH,
        confidence_threshold=0.5,
        device=None,
        category_remapping=None,
        load_at_init=True,
    )

    # prepare image and run inference
    image = read_image("tests/data/small-vehicles1.jpeg")
    yolov5_detection_model.perform_inference(image)

    # convert predictions to ObjectPrediction list
    yolov5_detection_model.convert_original_predictions()
    predictions = yolov5_detection_model.object_prediction_list

    # compare
    self.assertEqual(len(predictions), 14)
    self.assertEqual(predictions[0].category.id, 2)
    self.assertEqual(predictions[0].category.name, "car")

    # first bbox must land strictly within +-margin of the expected coco bbox
    desired_bbox = [321, 322, 62, 40]
    margin = 2
    for got, want in zip(predictions[0].bbox.to_coco_bbox(), desired_bbox):
        assert want - margin < got < want + margin

    self.assertEqual(predictions[5].category.id, 2)
    self.assertEqual(predictions[5].category.name, "car")
    self.assertEqual(
        predictions[5].bbox.to_coco_bbox(),
        [617, 195, 24, 23],
    )
def test_convert_original_predictions_without_mask_output(self):
    """RetinaNet (box-only) raw output converts to expected ObjectPredictions."""
    from sahi.model import MmdetDetectionModel

    # init model
    download_mmdet_retinanet_model()
    mmdet_detection_model = MmdetDetectionModel(
        model_path=MmdetTestConstants.MMDET_RETINANET_MODEL_PATH,
        config_path=MmdetTestConstants.MMDET_RETINANET_CONFIG_PATH,
        confidence_threshold=0.5,
        device=None,
        category_remapping=None,
        load_at_init=True,
    )

    # prepare image and run inference
    image = read_image("tests/data/small-vehicles1.jpeg")
    mmdet_detection_model.perform_inference(image)

    # convert predictions to ObjectPrediction list
    mmdet_detection_model.convert_original_predictions()
    predictions = mmdet_detection_model.object_prediction_list

    # compare
    self.assertEqual(len(predictions), 100)

    self.assertEqual(predictions[0].category.id, 2)
    self.assertEqual(predictions[0].category.name, "car")
    self.assertEqual(predictions[0].bbox.to_coco_bbox(), [448, 309, 47, 32])

    self.assertEqual(predictions[5].category.id, 2)
    self.assertEqual(predictions[5].category.name, "car")
    self.assertEqual(predictions[5].bbox.to_coco_bbox(), [523, 225, 22, 17])
def test_perform_inference(self):
    """Yolov5 inference returns raw detections containing an expected car box."""
    from sahi.model import Yolov5DetectionModel

    # init model
    download_yolov5s6_model()
    yolov5_detection_model = Yolov5DetectionModel(
        model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH,
        confidence_threshold=0.5,
        device=None,
        category_remapping=None,
        load_at_init=True,
    )

    # prepare image and run inference
    image = read_image("tests/data/small-vehicles1.jpeg")
    yolov5_detection_model.perform_inference(image)
    original_predictions = yolov5_detection_model.original_predictions
    boxes = original_predictions.xyxy

    # find box of first car detection with conf greater than 0.5
    for box in boxes[0]:
        is_car = box[5].item() == 2  # category index 2 is "car"
        if is_car and box[4].item() > 0.5:
            break

    # compare: each corner must land strictly within +-margin of the expectation
    desired_bbox = [321, 322, 383, 362]
    predicted_bbox = [int(value) for value in box[:4].tolist()]
    margin = 2
    for got, want in zip(predicted_bbox, desired_bbox):
        assert want - margin < got < want + margin

    self.assertEqual(len(original_predictions.names), 80)
def test_perform_inference(self):
    """Cascade Mask R-CNN inference returns per-category boxes and masks."""
    from sahi.model import MmdetDetectionModel
    from sahi.prediction import PredictionInput

    # init model
    download_mmdet_cascade_mask_rcnn_model()
    mmdet_detection_model = MmdetDetectionModel(
        model_path=mmdet_cascade_mask_rcnn_model_path,
        config_path=mmdet_cascade_mask_rcnn_config_path,
        prediction_score_threshold=0.5,
        device=None,
        category_remapping=None,
    )
    mmdet_detection_model.load_model()

    # prepare image
    image_path = "tests/data/small-vehicles1.jpeg"
    image = read_image(image_path)

    # perform inference
    mmdet_detection_model.perform_inference(image)
    original_predictions = mmdet_detection_model.original_predictions
    boxes = original_predictions[0]
    masks = original_predictions[1]

    # find box of first person detection with conf greater than 0.5
    # (leftover debug print removed; fail explicitly when nothing matches
    # instead of silently asserting against the last box)
    for box in boxes[0]:
        if len(box) == 5 and box[4] > 0.5:
            break
    else:
        self.fail("no person detection with confidence > 0.5 found")

    # compare
    self.assertEqual(box[:4].astype("int").tolist(), [336, 123, 346, 139])
    self.assertEqual(len(boxes), 80)
    self.assertEqual(len(masks), 80)
def test_perform_inference_with_mask_output(self):
    """Cascade Mask R-CNN inference returns per-category boxes and masks."""
    from sahi.model import MmdetDetectionModel

    # init model
    download_mmdet_cascade_mask_rcnn_model()
    mmdet_detection_model = MmdetDetectionModel(
        model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH,
        config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH,
        confidence_threshold=0.5,
        device=None,
        category_remapping=None,
        load_at_init=True,
    )

    # prepare image
    image_path = "tests/data/small-vehicles1.jpeg"
    image = read_image(image_path)

    # perform inference
    mmdet_detection_model.perform_inference(image)
    original_predictions = mmdet_detection_model.original_predictions
    boxes = original_predictions[0]
    masks = original_predictions[1]

    # find box of first person detection with conf greater than 0.5
    # (leftover debug print removed; fail explicitly when nothing matches
    # instead of silently asserting against the last box)
    for box in boxes[0]:
        if len(box) == 5 and box[4] > 0.5:
            break
    else:
        self.fail("no person detection with confidence > 0.5 found")

    # compare
    self.assertEqual(box[:4].astype("int").tolist(), [336, 123, 346, 139])
    self.assertEqual(len(boxes), 80)
    self.assertEqual(len(masks), 80)
def test_slice_image(self):
    """slice_image produces identical results for path, ndarray and PIL inputs.

    The original body repeated the same slice call and assertion set three
    times verbatim; the shared logic is extracted into local helpers.
    """
    # read coco file
    coco_path = "tests/data/coco_utils/terrain1_coco.json"
    coco = Coco.from_coco_dict_or_path(coco_path)

    output_file_name = None
    output_dir = None
    image_path = "tests/data/coco_utils/" + coco.images[0].file_name

    def run_slice(image):
        # slice with one fixed parameter set so every input type is comparable
        return slice_image(
            image=image,
            coco_annotation_list=coco.images[0].annotations,
            output_file_name=output_file_name,
            output_dir=output_dir,
            slice_height=512,
            slice_width=512,
            overlap_height_ratio=0.1,
            overlap_width_ratio=0.4,
            min_area_ratio=0.1,
            out_ext=".png",
            verbose=False,
        )

    def check_result(slice_image_result):
        # identical expectations regardless of how the image was supplied
        self.assertEqual(len(slice_image_result.images), 18)
        self.assertEqual(len(slice_image_result.coco_images), 18)
        self.assertEqual(slice_image_result.coco_images[0].annotations, [])
        self.assertEqual(slice_image_result.coco_images[15].annotations[1].area, 7296)
        self.assertEqual(
            slice_image_result.coco_images[15].annotations[1].bbox,
            [17, 186, 48, 152],
        )

    # from file path
    check_result(run_slice(image_path))
    # from numpy array
    check_result(run_slice(read_image(image_path)))
    # from PIL image
    check_result(run_slice(Image.open(image_path)))
def get_prediction(
    image,
    detection_model,
    shift_amount: list = None,
    full_image_size=None,
    merger=None,
    matcher=None,
):
    """
    Perform prediction for the given image using the given detection model.

    Arguments:
        image: str or np.ndarray
            Location of image or numpy image matrix to slice
        detection_model: model.DetectionModel
        shift_amount: List
            To shift the box and mask predictions from sliced image to
            full sized image, should be in the form of [shift_x, shift_y].
            Defaults to [0, 0].
        full_image_size: List
            Size of the full image, should be in the form of [height, width]
        merger: postprocess.PredictionMerger
        matcher: postprocess.PredictionMatcher
    Returns:
        A dict with fields:
            object_prediction_list: a list of ObjectPrediction
            durations_in_seconds: a dict containing elapsed times for profiling
    """
    # BUGFIX: avoid the mutable-default-argument pitfall — the previous
    # signature used `shift_amount: list = [0, 0]`, a single list shared
    # across all calls; create the default per call instead.
    if shift_amount is None:
        shift_amount = [0, 0]

    durations_in_seconds = dict()

    # read image if image is str
    if isinstance(image, str):
        image = read_image(image)

    # get prediction
    time_start = time.time()
    detection_model.perform_inference(image)
    durations_in_seconds["prediction"] = time.time() - time_start

    # process prediction
    time_start = time.time()
    # works only with 1 batch
    detection_model.convert_original_predictions(
        shift_amount=shift_amount,
        full_image_size=full_image_size,
    )
    object_prediction_list = detection_model.object_prediction_list

    # drop predictions below the model's score threshold
    filtered_object_prediction_list = [
        object_prediction for object_prediction in object_prediction_list
        if object_prediction.score.score > detection_model.prediction_score_threshold
    ]

    # merge matching predictions
    if merger is not None:
        filtered_object_prediction_list = merger.merge_batch(
            matcher,
            filtered_object_prediction_list,
            merge_type="merge",
        )
    durations_in_seconds["postprocess"] = time.time() - time_start

    return {
        "object_prediction_list": filtered_object_prediction_list,
        "durations_in_seconds": durations_in_seconds,
    }
def predict_folder(
    model_name="MmdetDetectionModel",
    model_parameters=None,
    image_dir=None,
    visual_output_dir=None,
    pickle_dir=None,
    crop_dir=None,
    apply_sliced_prediction: bool = True,
    slice_height: int = 256,
    slice_width: int = 256,
    overlap_height_ratio: float = 0.1,
    overlap_width_ratio: float = 0.2,
    visual_bbox_thickness: int = 1,
    visual_text_size: float = 1,
    visual_text_thickness: int = 1,
):
    """
    Performs prediction for all present images in given folder.

    Args:
        model_name: str
            Name of the implemented DetectionModel in model.py file.
        model_parameters: a dict with fields:
            model_path: str
                Path for the instance segmentation model weight
            config_path: str
                Path for the mmdetection instance segmentation model config file
            prediction_score_threshold: float
                All predictions with score < prediction_score_threshold will be discarded.
            device: str
                Torch device, "cpu" or "cuda"
            category_mapping: dict
                Required key passed through to the model constructor.
                NOTE(review): previously undocumented — confirm expected format
                against the DetectionModel implementation.
            category_remapping: dict: str to int
                Remap category ids after performing inference
        image_dir: str
            Directory that contain images to be predicted.
        visual_output_dir: str
            Directory that prediction visuals are going to be exported. Set to None for no visuals.
        pickle_dir: str
            Directory that object prediction list pickles are going to be exported. Set to None for no pickles.
        crop_dir: str
            Directory that detected bounding boxes will be cropped&exported. Set to None for no crops.
        apply_sliced_prediction: bool
            Set to True if you want sliced prediction, set to False for full prediction.
        slice_height: int
            Height of each slice. Defaults to ``256``.
        slice_width: int
            Width of each slice. Defaults to ``256``.
        overlap_height_ratio: float
            Fractional overlap in height of each window. Defaults to ``0.1``.
            (BUGFIX: docstring previously claimed ``0.2``, contradicting the signature.)
        overlap_width_ratio: float
            Fractional overlap in width of each window (e.g. an overlap of 0.2 for a window
            of size 256 yields an overlap of 51 pixels). Defaults to ``0.2``.
        visual_bbox_thickness: int
        visual_text_size: float
        visual_text_thickness: int
    """
    image_path_list = list_files(directory=image_dir,
                                 contains=[".jpg", ".jpeg", ".png"])

    # init model instance
    DetectionModel = import_class(model_name)
    detection_model = DetectionModel(
        model_path=model_parameters["model_path"],
        config_path=model_parameters["config_path"],
        prediction_score_threshold=model_parameters[
            "prediction_score_threshold"],
        device=model_parameters["device"],
        category_mapping=model_parameters["category_mapping"],
        category_remapping=model_parameters["category_remapping"],
    )
    detection_model.load_model()

    # iterate over source images
    for image_path in tqdm(image_path_list):
        # get filename
        (
            filename_with_extension,
            filename_without_extension,
        ) = get_base_filename(path=image_path)
        # load image
        image = read_image(image_path)

        # perform prediction
        if apply_sliced_prediction:
            # get sliced prediction
            prediction_result = get_sliced_prediction(
                image=image,
                detection_model=detection_model,
                slice_height=slice_height,
                slice_width=slice_width,
                overlap_height_ratio=overlap_height_ratio,
                overlap_width_ratio=overlap_width_ratio,
            )
        else:
            # get full sized prediction
            prediction_result = get_prediction(
                image=image,
                detection_model=detection_model,
                shift_amount=[0, 0],
                full_image_size=None,
                merger=None,
                matcher=None,
            )
        # identical extraction in both branches, hoisted out of the if/else
        object_prediction_list = prediction_result["object_prediction_list"]

        # export prediction boxes
        if crop_dir:
            crop_object_predictions(
                image=image,
                object_prediction_list=object_prediction_list,
                output_dir=crop_dir,
                file_name=filename_without_extension,
            )
        # export prediction list as pickle
        if pickle_dir:
            save_path = os.path.join(
                pickle_dir,
                filename_without_extension + ".pickle",
            )
            save_pickle(data=object_prediction_list, save_path=save_path)
        # export visualization
        if visual_output_dir:
            visualize_object_predictions(
                image,
                object_prediction_list=object_prediction_list,
                rect_th=visual_bbox_thickness,
                text_size=visual_text_size,
                text_th=visual_text_thickness,
                output_dir=visual_output_dir,
                file_name=filename_without_extension,
            )
import matplotlib.pyplot as plt import matplotlib # import required functions, classes from sahi.model import Yolov5DetectionModel from sahi.utils.cv import read_image, visualize_object_predictions, ipython_display from sahi.predict import get_prediction, get_sliced_prediction, predict detection_model = Yolov5DetectionModel( model_path= '/data2/mritime/yolov5/yolov5/runs/train/6epochs_1024_smd_hous1_charls0/weights/best.pt', prediction_score_threshold=0.3, device="cuda", # or 'cpu' ) image_dir = "/data2/mritime/youtube_videos/houston_1080/fr_136290.png" image = read_image(image_dir) # result = get_prediction(image, detection_model) def vis_pred(): result = get_sliced_prediction(image, detection_model, slice_height=1024, slice_width=1024, overlap_height_ratio=0.2, overlap_width_ratio=0.2) visualization_result = visualize_object_predictions( image, object_prediction_list=result["object_prediction_list"], output_dir=None,