Пример #1
0
def get_ground_truths_bbox(img_path, shpfile_path):
    """Build ground-truth boxes for one raster image from a shapefile.

    Every feature of the shapefile whose bounds pass ``bbox_is_in_img``
    against the raster's bounding rectangle is converted to pixel row/column
    corners with ``cvt_row_col`` and registered as a "tree" ground-truth box.

    Args:
        img_path: Path of the raster image, opened with ``rasterio``.
        shpfile_path: Path of the shapefile, read with ``geopandas``.

    Returns:
        A ``BoundingBoxes`` collection holding one ground-truth
        ``BoundingBox`` per selected feature.
    """
    allBoundingBoxes = BoundingBoxes()
    # NOTE: takes the text after the last "/" and drops its final four
    # characters, i.e. it assumes a three-letter extension such as ".tif".
    img_name = img_path.split("/")[-1][:-4]
    # The context manager closes the raster even if a later step raises;
    # previously the dataset handle was never released.
    with rasterio.open(img_path) as dataset:
        polygons = gpd.read_file(shpfile_path)
        # Rectangle built from the four corners of the raster bounds.
        bounds = dataset.bounds
        dataset_polygon = Polygon([(bounds[0], bounds[1]),
                                   (bounds[2], bounds[1]),
                                   (bounds[2], bounds[3]),
                                   (bounds[0], bounds[3])])
        df_bounds = polygons.bounds
        inside_indices = [
            index for index, bbox in enumerate(df_bounds.values.tolist())
            if bbox_is_in_img(bbox, dataset_polygon)
        ]
        for index in tqdm(inside_indices):
            row_max, col_max = cvt_row_col(dataset, df_bounds, index, "max")
            row_min, col_min = cvt_row_col(dataset, df_bounds, index, "min")
            bb = BoundingBox(img_name,
                             "tree",
                             col_min,
                             row_min,
                             col_max,
                             row_max,
                             CoordinatesType.Absolute,
                             None,
                             BBType.GroundTruth,
                             format=BBFormat.XYX2Y2)
            allBoundingBoxes.addBoundingBox(bb)
    return allBoundingBoxes
Пример #2
0
 def AddDetectedVales(self, row, confidence):
     """Register one detected box described by ``row``.

     ``row`` supplies the ``filename``, ``classname``, ``xmin``, ``ymin``,
     ``xmax``, ``ymax``, ``height`` and ``width`` attributes. The corners are
     passed as absolute XYX2Y2 coordinates together with ``confidence``, and
     the resulting box is added to ``self.allBoundingBoxes``.
     """
     image_name, label = row.filename, row.classname
     corners = (row.xmin, row.ymin, row.xmax, row.ymax)
     image_size = (row.height, row.width)
     detection = BoundingBox(image_name, label, *corners,
                             CoordinatesType.Absolute, image_size,
                             BBType.Detected, confidence,
                             format=BBFormat.XYX2Y2)
     self.allBoundingBoxes.addBoundingBox(detection)
Пример #3
0
 def AddGroundTruth(self, row):
     """Register one ground-truth box described by ``row``.

     ``row`` supplies the ``filename``, ``classname``, ``xmin``, ``ymin``,
     ``xmax``, ``ymax``, ``height`` and ``width`` attributes. The corners are
     passed as absolute XYX2Y2 coordinates and the resulting box is added to
     ``self.allBoundingBoxes``.
     """
     image_name, label = row.filename, row.classname
     corners = (row.xmin, row.ymin, row.xmax, row.ymax)
     image_size = (row.height, row.width)
     truth = BoundingBox(image_name, label, *corners,
                         CoordinatesType.Absolute, image_size,
                         BBType.GroundTruth,
                         format=BBFormat.XYX2Y2)
     self.allBoundingBoxes.addBoundingBox(truth)
def _predict_in_windows(model, image, window_size=256, overlap=0.5,
                        confidence=0.7):
    """Run the detector on sliding windows of ``image``.

    Returns a list with one box table per window for which ``get_prediction``
    returned a result; the xmin/ymin/xmax/ymax columns are shifted from
    window-local to full-image pixel coordinates.
    """
    windows = compute_windows(image, window_size, overlap)
    shifted = []
    for window in tqdm(windows):
        crop = image[window.indices()]
        boxes = get_prediction(model, crop, confidence=confidence)
        if boxes is None:
            continue
        for column in ("xmin", "ymin", "xmax", "ymax"):
            boxes[column] = pd.to_numeric(boxes[column])
        # getRect() yields (x, y, w, h) of the window inside the full image.
        x_offset, y_offset, _, _ = window.getRect()
        boxes.xmin = boxes.xmin + x_offset
        boxes.xmax = boxes.xmax + x_offset
        boxes.ymin = boxes.ymin + y_offset
        boxes.ymax = boxes.ymax + y_offset
        shifted.append(boxes)
    return shifted


def _suppress_overlapping_boxes(predicted_boxes, iou_threshold=0.5):
    """Apply non-max suppression and return the kept boxes as a DataFrame."""
    with tf.Session() as sess:
        print(
            "{} predictions in overlapping windows, applying non-max supression". \
            format(predicted_boxes.shape[0]))
        new_boxes, new_scores, new_labels = non_max_suppression(
            sess,
            predicted_boxes[["xmin", "ymin", "xmax", "ymax"]].values,
            predicted_boxes.score.values,
            predicted_boxes.label.values,
            max_output_size=predicted_boxes.shape[0],
            iou_threshold=iou_threshold)
        # Re-assemble boxes, scores and labels into one table.
        image_detections = np.concatenate([
            new_boxes,
            np.expand_dims(new_scores, axis=1),
            np.expand_dims(new_labels, axis=1)
        ], axis=1)
    return pd.DataFrame(
        image_detections,
        columns=["xmin", "ymin", "xmax", "ymax", "score", "label"])


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is 0."""
    return numerator / denominator if denominator else float("nan")


def _print_detection_metrics(img_path, shfPath, df):
    """Print per-class precision and recall of ``df`` against the shapefile."""
    allBoundingBoxes = get_ground_truths_bbox(img_path, shfPath)
    # Same image-name derivation as get_ground_truths_bbox so that detections
    # and ground truths are matched under one name.
    image_name = img_path.split("/")[-1][:-4]
    for ele in tqdm(df.values.tolist()):
        bb = BoundingBox(
            image_name,
            ele[5].decode(),  # label
            ele[0],  # x_min
            ele[1],  # y_min
            ele[2],  # x_max
            ele[3],  # y_max
            CoordinatesType.Absolute,
            None,
            BBType.Detected,
            float(ele[4]),  # confidence
            format=BBFormat.XYX2Y2)
        allBoundingBoxes.addBoundingBox(bb)

    evaluator = Evaluator()
    metricsPerClass = evaluator.GetPascalVOCMetrics(
        allBoundingBoxes,
        IOUThreshold=0.5,
        method=MethodAveragePrecision.EveryPointInterpolation)
    print("Precision values per class:\n")
    for mc in metricsPerClass:
        c = mc['class']
        total_TP = mc['total TP']
        total_FP = mc['total FP']
        total_groundTruths = mc['total positives']
        # A class with no detections or no ground truths has an undefined
        # ratio; report NaN instead of dividing by zero.
        print("Precision: {}: {}".format(
            c, _safe_ratio(total_TP, total_TP + total_FP)))
        print('Recall: {}: {}'.format(
            c, _safe_ratio(total_TP, total_groundTruths)))


def process_one_image(img_path, shfPath, evaluate, model_path, class_list):
    """Detect objects in one raster image and export them as a shapefile.

    The image is processed in overlapping 256-pixel windows, the per-window
    detections are merged with non-max suppression, optionally scored against
    the ground-truth shapefile, converted to geometries with
    ``convert_xy_tif`` and written to
    ``./demo/output_pred/with_shapely_pred.shp``.

    Args:
        img_path: Path of the raster image.
        shfPath: Path of the ground-truth shapefile; only read when
            ``evaluate`` is truthy.
        evaluate: When truthy, print per-class precision and recall.
        model_path: Path of the serialized model passed to ``torch.load``.
        class_list: Path of the CSV class file parsed by ``load_classes``.

    Returns:
        None. When no window produces a prediction, a message is printed and
        nothing is written.
    """
    # The parsed classes are not used below; the call is retained so that any
    # error load_classes raises on a bad class file still surfaces here.
    with open(class_list, 'r') as f:
        load_classes(csv.reader(f, delimiter=','))

    # SECURITY NOTE: torch.load unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    model = torch.load(model_path)
    if torch.cuda.is_available():
        model = model.cuda()
    model.training = False
    model.eval()

    # A single handle, closed on every exit path, replaces the two separate
    # opens (one of which was never closed).
    with rasterio.open(img_path) as dataset:
        arr = dataset.read()  # read all raster values
        print(dataset.height, dataset.width, dataset.transform, dataset.crs)

        # Move the first axis last so windows can be cropped from the array.
        rgb1 = np.rollaxis(arr, 0, 3)
        window_boxes = _predict_in_windows(model, rgb1)
        if not window_boxes:
            # pd.concat raises ValueError on an empty list.
            print("No predictions found; nothing to export.")
            return

        df = _suppress_overlapping_boxes(pd.concat(window_boxes))
        if evaluate:
            _print_detection_metrics(img_path, shfPath, df)

        # convert_xy_tif receives the dataset, so keep it open until here.
        df['geometry'] = df.apply(lambda x: convert_xy_tif(x, dataset),
                                  axis=1)

    df_res = gpd.GeoDataFrame(df[["xmin", "ymin", "xmax", "ymax", "geometry"]],
                              geometry='geometry')
    df_res.to_file('./demo/output_pred/with_shapely_pred.shp',
                   driver='ESRI Shapefile')
    print("-----------Done--------------")
Пример #5
0
def getBoundingBoxes():
    """Read txt files containing bounding boxes (ground truth and detections).

    Ground truths are read from the ``groundtruths`` folder and detections
    from the ``detections`` folder, both located next to this file. Each
    non-blank line is split on single spaces:

    * ground truth: ``class_id x y width height``
    * detection: ``class_id confidence x y width height``

    All boxes are registered as absolute XYWH boxes for a (200, 200) image,
    under the file name with ``.txt`` removed.

    Side effect: the process working directory is changed to each folder in
    turn and is left at the ``detections`` folder on return.

    Returns:
        A ``BoundingBoxes`` collection with every ground truth and detection.
    """
    import glob
    import os

    allBoundingBoxes = BoundingBoxes()
    currentPath = os.path.dirname(os.path.abspath(__file__))
    # (folder name, box type, whether a confidence column precedes the box)
    sources = (
        ('groundtruths', BBType.GroundTruth, False),
        ('detections', BBType.Detected, True),
    )
    for folder_name, bb_type, has_confidence in sources:
        os.chdir(os.path.join(currentPath, folder_name))
        # Index of the x column: shifted by one when a confidence is present.
        first_coord = 2 if has_confidence else 1
        for file_name in sorted(glob.glob("*.txt")):
            nameOfImage = file_name.replace(".txt", "")
            # "with" closes the file even when a malformed line raises.
            with open(file_name, "r") as fh1:
                for line in fh1:
                    line = line.replace("\n", "")
                    if line.replace(' ', '') == '':
                        continue
                    splitLine = line.split(" ")
                    idClass = splitLine[0]  # class
                    # Confidence is passed positionally only for detections.
                    extra = (float(splitLine[1]),) if has_confidence else ()
                    x = float(splitLine[first_coord])
                    y = float(splitLine[first_coord + 1])
                    w = float(splitLine[first_coord + 2])
                    h = float(splitLine[first_coord + 3])
                    bb = BoundingBox(nameOfImage,
                                     idClass,
                                     x,
                                     y,
                                     w,
                                     h,
                                     CoordinatesType.Absolute, (200, 200),
                                     bb_type,
                                     *extra,
                                     format=BBFormat.XYWH)
                    allBoundingBoxes.addBoundingBox(bb)
    return allBoundingBoxes
Пример #6
0
                                                shuffle=False,
                                                num_workers=4,
                                                collate_fn=utils.collate_fn)
 model = init_model()
 imgs = glob.glob(os.path.join('validating_data/', "*"))
 # rand_img = random.sample(imgs, 1)
 # import ipdb; ipdb.set_trace()
 myBoundingBoxes = BoundingBoxes()
 for idx in range(len(dataset_test)):
     img, target = dataset_test.__getitem__(idx)
     for i in range(len(target['boxes'])):
         gt_boundingBox = BoundingBox(
             imageName=target["img_name"],
             classId=CLASS_NAMES[target['labels'][i].item()],
             x=target['boxes'][i][0].item(),
             y=target['boxes'][i][1].item(),
             w=target['boxes'][i][2].item(),
             h=target['boxes'][i][3].item(),
             typeCoordinates=CoordinatesType.Absolute,
             bbType=BBType.GroundTruth,
             format=BBFormat.XYX2Y2)
         myBoundingBoxes.addBoundingBox(gt_boundingBox)
     image_path = glob.glob(
         os.path.join('validating_data/',
                      "{}.jpg".format(target["img_name"])))[0]
     print("predict {}".format(target["img_name"]))
     #failed at predict 000003085
     pred_boxes, pred_class, pred_score = get_prediction(
         model, image_path, 0.5)  # Get predictions
     # import ipdb; ipdb.set_trace()
     for idx_detect in range(len(pred_boxes)):
         detected_boundingBox = BoundingBox(
Пример #7
0
def getBoundingBoxes(directory,
                     isGT,
                     bbFormat,
                     coordType,
                     allBoundingBoxes=None,
                     allClasses=None,
                     imgSize=(0, 0)):
    """Read txt files containing bounding boxes (ground truth or detections).

    Every ``*.txt`` file in ``directory`` is read in sorted order; each
    non-blank line is split on single spaces and turned into one box named
    after the file (with ``.txt`` removed):

    * ``isGT`` truthy: ``class_id v1 v2 v3 v4`` -> ground-truth box
    * otherwise: ``class_id confidence v1 v2 v3 v4`` -> detected box

    The four values are interpreted according to ``bbFormat``.

    Side effect: the process working directory is changed to ``directory``.

    Args:
        directory: Folder holding the txt files.
        isGT: Whether the files describe ground truths.
        bbFormat: Box format forwarded to ``BoundingBox``.
        coordType: Coordinate type forwarded to ``BoundingBox``.
        allBoundingBoxes: Collection to extend; a new ``BoundingBoxes`` is
            created when None.
        allClasses: List of class ids to extend; a new list when None.
        imgSize: Image size forwarded to ``BoundingBox``.

    Returns:
        The tuple ``(allBoundingBoxes, allClasses)``.
    """
    if allBoundingBoxes is None:
        allBoundingBoxes = BoundingBoxes()
    if allClasses is None:
        allClasses = []
    os.chdir(directory)
    for file_name in sorted(glob.glob("*.txt")):
        nameOfImage = file_name.replace(".txt", "")
        # "with" closes the file even when a malformed line raises.
        with open(file_name, "r") as fh1:
            for line in fh1:
                line = line.replace("\n", "")
                if line.replace(' ', '') == '':
                    continue
                splitLine = line.split(" ")
                idClass = splitLine[0]  # class
                if isGT:
                    x = float(splitLine[1])
                    y = float(splitLine[2])
                    w = float(splitLine[3])
                    h = float(splitLine[4])
                    bb = BoundingBox(nameOfImage,
                                     idClass,
                                     x,
                                     y,
                                     w,
                                     h,
                                     coordType,
                                     imgSize,
                                     BBType.GroundTruth,
                                     format=bbFormat)
                else:
                    confidence = float(splitLine[1])
                    x = float(splitLine[2])
                    y = float(splitLine[3])
                    w = float(splitLine[4])
                    h = float(splitLine[5])
                    bb = BoundingBox(nameOfImage,
                                     idClass,
                                     x,
                                     y,
                                     w,
                                     h,
                                     coordType,
                                     imgSize,
                                     BBType.Detected,
                                     confidence,
                                     format=bbFormat)
                allBoundingBoxes.addBoundingBox(bb)
                if idClass not in allClasses:
                    allClasses.append(idClass)
    return allBoundingBoxes, allClasses
Пример #8
0
def getBoundingBoxes(directory,
                     isGT,
                     bbFormat,
                     coordType,
                     allBoundingBoxes=None,
                     allClasses=None,
                     imgSize=(0, 0)):
    """Read bounding boxes (ground truth or detections) from a JSON file.

    ``directory`` is the path of a UTF-8 JSON file holding a mapping from an
    image path to a list of values. The image name is the last "/"-separated
    part of the key, cut at its first ".". Every box starts at (0, 0):

    * ``isGT`` truthy: ``[class_id, w, h]`` -> ground-truth box
    * otherwise: ``[class_id, confidence, w, h]`` -> detected box

    Returns the tuple ``(allBoundingBoxes, allClasses)``; both are created
    when passed as None and extended otherwise.
    """
    if allBoundingBoxes is None:
        allBoundingBoxes = BoundingBoxes()
    if allClasses is None:
        allClasses = []

    with open(directory, 'r', encoding="utf-8") as json_file:
        data_info = json.load(json_file)

    for image_path, record in data_info.items():
        nameOfImage = str(image_path.split('/')[-1]).split('.')[0]
        idClass = record[0]  # class
        if isGT:
            width = float(record[1])
            height = float(record[2])
            bb = BoundingBox(nameOfImage, idClass, float(0), float(0),
                             width, height, coordType, imgSize,
                             BBType.GroundTruth, format=bbFormat)
        else:
            confidence = float(record[1])
            width = float(record[2])
            height = float(record[3])
            bb = BoundingBox(nameOfImage, idClass, float(0), float(0),
                             width, height, coordType, imgSize,
                             BBType.Detected, confidence, format=bbFormat)
        allBoundingBoxes.addBoundingBox(bb)
        if idClass not in allClasses:
            allClasses.append(idClass)
    return allBoundingBoxes, allClasses
Пример #9
0
def _load_tab_separated_boxes(dir_path, has_score):
    """Parse every ``*.txt`` file of ``dir_path`` into per-image box records.

    Returns ``{image_name: [obj_dict, ...]}`` where each ``obj_dict`` holds
    ``class_name``, ``bbox`` (four floats) and, when ``has_score`` is true,
    ``score``. Blank lines are skipped.
    """
    boxes_by_image = {}
    for file_path in glob.glob(os.path.join(dir_path, '*.txt')):
        image_name = os.path.splitext(os.path.basename(file_path))[0]
        records = boxes_by_image[image_name] = []
        with open(file_path, 'r') as f:
            for line in f:
                # A trailing empty line would otherwise fail the unpacking.
                if not line.strip():
                    continue
                obj_dict = {}
                if has_score:
                    class_name, score, sx, sy, ex, ey = line.split('\t')
                    obj_dict['class_name'] = class_name
                    obj_dict['score'] = float(score)
                else:
                    class_name, sx, sy, ex, ey = line.split('\t')
                    obj_dict['class_name'] = class_name
                obj_dict['bbox'] = [float(sx), float(sy), float(ex), float(ey)]
                records.append(obj_dict)
    return boxes_by_image


def get_bboxes_and_classes(ground_truth_dir_path, prediction_dir_path,
                           score_threshold, ios_threshold, iou_threshold):
    """Collect ground-truth and filtered detected boxes for evaluation.

    Both directories hold one tab-separated ``*.txt`` file per image:
    ``class sx sy ex ey`` for ground truths and ``class score sx sy ex ey``
    for predictions. For every image that has a prediction file, its ground
    truths are registered, predictions scoring below ``score_threshold`` are
    dropped, the rest go through ``non_maximum_suppression`` and the
    survivors are registered as detections. All boxes use absolute XYX2Y2
    coordinates.

    Args:
        ground_truth_dir_path: Folder with the ground-truth txt files.
        prediction_dir_path: Folder with the prediction txt files.
        score_threshold: Minimum score a prediction needs to be kept.
        ios_threshold: Forwarded to ``non_maximum_suppression``.
        iou_threshold: Forwarded to ``non_maximum_suppression``.

    Returns:
        A ``BoundingBoxes`` collection with ground truths and detections.
    """
    gt_dict = _load_tab_separated_boxes(ground_truth_dir_path, False)
    predictions_dict = _load_tab_separated_boxes(prediction_dir_path, True)

    allBoundingBoxes = BoundingBoxes()

    # NOTE(review): images that only have a ground-truth file are not visited
    # here, so their ground truths never reach the evaluation - confirm that
    # this is intended.
    for img_filename, prediction in predictions_dict.items():

        # An image without a ground-truth file has no ground truths; its
        # detections are still registered instead of raising KeyError.
        for obj in gt_dict.get(img_filename, []):
            bbox = obj["bbox"]
            bb = BoundingBox(img_filename,
                             obj['class_name'],
                             bbox[0],
                             bbox[1],
                             bbox[2],
                             bbox[3],
                             CoordinatesType.Absolute,
                             None,
                             BBType.GroundTruth,
                             format=BBFormat.XYX2Y2)
            allBoundingBoxes.addBoundingBox(bb)

        # Keep only confident predictions before non-maximum suppression.
        kept = [obj for obj in prediction if obj["score"] >= score_threshold]
        classes = [obj["class_name"] for obj in kept]
        boxes = [obj["bbox"] for obj in kept]
        scores = [obj["score"] for obj in kept]

        classes, boxes, scores = non_maximum_suppression(
            classes, boxes, scores, ios_threshold, iou_threshold)
        for class_name, bbox, score in zip(classes, boxes, scores):
            bb = BoundingBox(img_filename,
                             class_name,
                             bbox[0],
                             bbox[1],
                             bbox[2],
                             bbox[3],
                             CoordinatesType.Absolute,
                             None,
                             BBType.Detected,
                             score,
                             format=BBFormat.XYX2Y2)
            allBoundingBoxes.addBoundingBox(bb)

    return allBoundingBoxes
Пример #10
0
    def on_batch_end(self, last_output, last_target, **kwargs):
        """Turn one batch of outputs and targets into bounding objects.

        ``last_target`` is unpacked as (ground-truth boxes, ground-truth
        classes) and the first two entries of ``last_output`` as (class
        predictions, box predictions). For at most ``self.images_per_batch``
        images the predictions are decoded with ``process_output``, truncated,
        filtered with ``nms`` and rescaled, and both ground truths and
        detections are added to ``self.boundingObjects`` under the image name
        ``str(self.imageCounter)`` with the fixed class id 'Mit'.

        When ``last_output`` has more than three entries (the "STN use case"),
        the same steps are repeated with entries 1 and 2 of ``last_output``
        (box predictions, then class predictions) and the results go to
        ``self.boundingObjectsSTN``.
        """
        bbox_gt_batch, class_gt_batch = last_target
        class_pred_batch, bbox_pred_batch = last_output[:2]

        # A non-positive images_per_batch is replaced by the batch size.
        self.images_per_batch = self.images_per_batch if self.images_per_batch > 0 else class_pred_batch.shape[
            0]
        for bbox_gt, class_gt, clas_pred, bbox_pred in \
                list(zip(bbox_gt_batch, class_gt_batch, class_pred_batch, bbox_pred_batch))[: self.images_per_batch]:

            # process_output returns a dict; only these three keys are used.
            out = process_output(clas_pred, bbox_pred, self.anchors,
                                 self.detect_thresh)
            bbox_pred, scores, preds = out['bbox_pred'], out['scores'], out[
                'preds']

            # No decoded boxes: skip the image (imageCounter is not advanced).
            if bbox_pred is None:  # or len(preds) > 3 * len(bbox_gt):
                continue

            #image = np.zeros((512, 512, 3), np.uint8)

            # If the number is too high, evaluation is very slow: keep at most
            # three candidates per ground-truth class entry before NMS.
            total_nms_examples = len(class_gt) * 3
            bbox_pred = bbox_pred[:total_nms_examples]
            scores = scores[:total_nms_examples]
            preds = preds[:total_nms_examples]
            to_keep = nms(bbox_pred, scores, self.nms_thresh)
            bbox_pred, preds, scores = bbox_pred[to_keep].cpu(
            ), preds[to_keep].cpu(), scores[to_keep].cpu()

            t_sz = torch.Tensor([(self.size, self.size)])[None].cpu()
            # Keep only ground-truth rows whose class entry is non-zero.
            bbox_gt = bbox_gt[np.nonzero(class_gt)].squeeze(dim=1).cpu()
            class_gt = class_gt[class_gt > 0]
            # change gt from x,y,x2,y2 -> x,y,w,h
            if (bbox_gt.shape[-1] == 4):
                bbox_gt[:, 2:] = bbox_gt[:, 2:] - bbox_gt[:, :2]

            bbox_gt = to_np(rescale_boxes(bbox_gt, t_sz))
            bbox_pred = to_np(rescale_boxes(bbox_pred, t_sz))
            # change from center to top left; note the test is on the
            # ground-truth width while the predictions are modified
            if (bbox_gt.shape[-1] == 4):
                bbox_pred[:, :2] = bbox_pred[:, :2] - bbox_pred[:, 2:] / 2

            # Shift ground-truth class ids down by one.
            class_gt = to_np(class_gt) - 1
            preds = to_np(preds)
            scores = to_np(scores)

            # Ground truths: a 4-value row becomes a box, otherwise a circle.
            # The per-box class `cla` is not used; the class id is 'Mit'.
            for box, cla in zip(bbox_gt, class_gt):
                if (bbox_gt.shape[-1] == 4):
                    temp = BoundingBox(
                        imageName=str(self.imageCounter),
                        classId='Mit',
                        x=box[0],
                        y=box[1],
                        w=box[2],
                        h=box[3],
                        typeCoordinates=CoordinatesType.Absolute,
                        bbType=BBType.GroundTruth,
                        format=BBFormat.XYWH,
                        imgSize=(self.size, self.size))

                    self.boundingObjects.addBoundingBox(temp)

                else:
                    temp = BoundingCircle(
                        imageName=str(self.imageCounter),
                        classId='Mit',
                        x=box[0],
                        y=box[1],
                        r=box[2],
                        typeCoordinates=CoordinatesType.Absolute,
                        bbType=BBType.GroundTruth,
                        imgSize=(self.size, self.size))

                    self.boundingObjects.addBoundingCircle(temp)

            # to reduce math complexity take maximal three times the number of gt boxes
            num_boxes = len(bbox_gt) * 3
            for box, cla, scor in list(zip(bbox_pred, preds,
                                           scores))[:num_boxes]:
                if (bbox_gt.shape[-1] == 4):
                    temp = BoundingBox(
                        imageName=str(self.imageCounter),
                        classId='Mit',
                        x=box[0],
                        y=box[1],
                        w=box[2],
                        h=box[3],
                        typeCoordinates=CoordinatesType.Absolute,
                        classConfidence=scor,
                        bbType=BBType.Detected,
                        format=BBFormat.XYWH,
                        imgSize=(self.size, self.size))

                    self.boundingObjects.addBoundingBox(temp)
                else:
                    temp = BoundingCircle(
                        imageName=str(self.imageCounter),
                        classId='Mit',
                        x=box[0],
                        y=box[1],
                        r=box[2],
                        typeCoordinates=CoordinatesType.Absolute,
                        classConfidence=scor,
                        bbType=BBType.Detected,
                        imgSize=(self.size, self.size))

                    self.boundingObjects.addBoundingCircle(temp)

            #image = self.boundingObjects.drawAllBoundingBoxes(image, str(self.imageCounter))
            # One image name per processed image.
            self.imageCounter += 1

        # Second pass over the same targets, stored in boundingObjectsSTN.
        # NOTE(review): within the visible part of this loop imageCounter is
        # never advanced, so every image of the batch would share one name -
        # confirm whether an increment follows or is missing.
        if len(last_output) > 3:  # STN use case
            _, bbox_pred_batch, class_pred_batch = last_output[:3]

            self.images_per_batch = self.images_per_batch if self.images_per_batch > 0 else class_pred_batch.shape[
                0]
            for bbox_gt, class_gt, clas_pred, bbox_pred in \
                    list(zip(bbox_gt_batch, class_gt_batch, class_pred_batch, bbox_pred_batch))[: self.images_per_batch]:

                # process_output returns a dict; only these three keys are used.
                out = process_output(clas_pred, bbox_pred, self.anchors,
                                     self.detect_thresh)
                bbox_pred, scores, preds = out['bbox_pred'], out[
                    'scores'], out['preds']

                if bbox_pred is None:  # or len(preds) > 3 * len(bbox_gt):
                    continue

                #image = np.zeros((512, 512, 3), np.uint8)

                # If the number is too high, evaluation is very slow.
                total_nms_examples = len(class_gt) * 3
                bbox_pred = bbox_pred[:total_nms_examples]
                scores = scores[:total_nms_examples]
                preds = preds[:total_nms_examples]
                to_keep = nms(bbox_pred, scores, self.nms_thresh)
                bbox_pred, preds, scores = bbox_pred[to_keep].cpu(
                ), preds[to_keep].cpu(), scores[to_keep].cpu()

                t_sz = torch.Tensor([(self.size, self.size)])[None].cpu()
                # Keep only ground-truth rows whose class entry is non-zero.
                bbox_gt = bbox_gt[np.nonzero(class_gt)].squeeze(dim=1).cpu()
                class_gt = class_gt[class_gt > 0]
                # change gt from x,y,x2,y2 -> x,y,w,h
                if (bbox_gt.shape[-1] == 4):
                    bbox_gt[:, 2:] = bbox_gt[:, 2:] - bbox_gt[:, :2]

                bbox_gt = to_np(rescale_boxes(bbox_gt, t_sz))
                bbox_pred = to_np(rescale_boxes(bbox_pred, t_sz))
                # change from center to top left
                if (bbox_gt.shape[-1] == 4):
                    bbox_pred[:, :2] = bbox_pred[:, :2] - bbox_pred[:, 2:] / 2

                class_gt = to_np(class_gt) - 1
                preds = to_np(preds)
                scores = to_np(scores)

                # Ground truths: a 4-value row becomes a box, otherwise a circle.
                for box, cla in zip(bbox_gt, class_gt):
                    if (bbox_gt.shape[-1] == 4):
                        temp = BoundingBox(
                            imageName=str(self.imageCounter),
                            classId='Mit',
                            x=box[0],
                            y=box[1],
                            w=box[2],
                            h=box[3],
                            typeCoordinates=CoordinatesType.Absolute,
                            bbType=BBType.GroundTruth,
                            format=BBFormat.XYWH,
                            imgSize=(self.size, self.size))

                        self.boundingObjectsSTN.addBoundingBox(temp)

                    else:
                        temp = BoundingCircle(
                            imageName=str(self.imageCounter),
                            classId='Mit',
                            x=box[0],
                            y=box[1],
                            r=box[2],
                            typeCoordinates=CoordinatesType.Absolute,
                            bbType=BBType.GroundTruth,
                            imgSize=(self.size, self.size))

                        self.boundingObjectsSTN.addBoundingCircle(temp)

                # to reduce math complexity take maximal three times the number of gt boxes
                num_boxes = len(bbox_gt) * 3
                for box, cla, scor in list(zip(bbox_pred, preds,
                                               scores))[:num_boxes]:
                    if (bbox_gt.shape[-1] == 4):
                        temp = BoundingBox(
                            imageName=str(self.imageCounter),
                            classId='Mit',
                            x=box[0],
                            y=box[1],
                            w=box[2],
                            h=box[3],
                            typeCoordinates=CoordinatesType.Absolute,
                            classConfidence=scor,
                            bbType=BBType.Detected,
                            format=BBFormat.XYWH,
                            imgSize=(self.size, self.size))

                        self.boundingObjectsSTN.addBoundingBox(temp)
                    else:
                        temp = BoundingCircle(
                            imageName=str(self.imageCounter),
                            classId='Mit',
                            x=box[0],
                            y=box[1],
                            r=box[2],
                            typeCoordinates=CoordinatesType.Absolute,
                            classConfidence=scor,
                            bbType=BBType.Detected,
                            imgSize=(self.size, self.size))

                        self.boundingObjectsSTN.addBoundingCircle(temp)