Ejemplo n.º 1
0
def plot_images(dataset, patient_id, data_arr, gt_arr, pred_arr, output_dir,
                bbox_flag, bbox_metrics, dice):
    """
    Plots 15 different views of a given patient imaging data, overlaying the
    ground-truth and prediction masks, and saves the figure to ``output_dir``.

    The view centers are the geometric centers of the ground-truth and
    prediction bounding boxes (taken from ``bbox_metrics`` and snapped with
    ``getClosestSlice``), not the masks' centers of mass.

    Args:
        dataset (str): Name of dataset.
        patient_id (str): Unique patient id.
        data_arr: Image volume to plot. Axis order presumably (Z, Y, X) to
            match the bbox metrics -- TODO confirm.
        gt_arr: Ground-truth mask array.
        pred_arr: Prediction array; passed through ``utils.threshold`` before
            it is used as a mask.
        output_dir (str): Path to folder where the png will be saved. Created
            if it does not exist.
        bbox_flag (bool): Whether to draw the bounding boxes.
        bbox_metrics (dict): Precomputed metrics (from the distance
            calculation) with keys ``"ground_truth_bbox_metrics"`` and
            ``"prediction_bbox_metrics"`` -- each mapping ``"Z"``, ``"Y"``,
            ``"X"`` to a dict with ``"min"``, ``"max"`` and ``"center"`` --
            and ``"distance"``, which is forwarded to ``plot_figure``.
        dice: Presumably the Dice overlap score; forwarded unchanged to
            ``plot_figure``.

    Returns:
        None

    Errors:
        Any exception raised while preparing or plotting -- including a
        mismatch between ``bbox_metrics`` and the bbox recomputed from
        ``gt_arr`` -- is caught and reported with ``print``; it is NOT
        re-raised, so one failing patient does not abort a batch run.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_dir, exist_ok=True)
    try:
        gt_bbox_metrics = bbox_metrics["ground_truth_bbox_metrics"]
        pred_bbox_metrics = bbox_metrics["prediction_bbox_metrics"]
        pred_arr = utils.threshold(pred_arr)
        mask_arr_list = [gt_arr, pred_arr]
        mask_list_names = ["gt", "pred"]
        gt_bbox = utils.get_bbox(gt_arr)
        pred_bbox = utils.get_bbox(pred_arr)

        # Geometric bbox centers (not center of mass) so each view is centered
        # on the box that is drawn; the names keep the historical "com".
        com_gt = getClosestSlice(
            tuple(gt_bbox_metrics[axis]["center"] for axis in "ZYX"))
        com_pred = getClosestSlice(
            tuple(pred_bbox_metrics[axis]["center"] for axis in "ZYX"))

        # Cross-check the precomputed metrics against the bbox computed here.
        # gt_bbox is laid out as (Zmin, Zmax, Ymin, Ymax, Xmin, Xmax).
        # Explicit check instead of `assert`, which is stripped under -O.
        expected_bounds = [(axis, bound) for axis in "ZYX"
                           for bound in ("min", "max")]
        for position, (axis, bound) in enumerate(expected_bounds):
            if gt_bbox_metrics[axis][bound] != gt_bbox[position]:
                raise ValueError("bbox calc incorrect")

        plot_figure(dataset, patient_id, data_arr, mask_arr_list,
                    mask_list_names, com_gt, com_pred, [gt_bbox, pred_bbox],
                    bbox_flag, output_dir, bbox_metrics["distance"], dice)

        print("{}_{} saved".format(dataset, patient_id))
    except Exception as e:
        # Deliberate best-effort: report and continue with the next patient.
        print("Error in {}_{}, {}".format(dataset, patient_id, e))
Ejemplo n.º 2
0
    def on_btn_pre_clicked(self):
        """Step back to the previous entry of ``anno_list`` and display it.

        No-op when already at the first entry. Otherwise sets the label text
        to "<scene_id>/<instance>" (plus a tab and the instance's label name
        when ``ins_label`` knows it), loads the entry's .npy file from
        ``loadroot``, shifts the cloud so its bbox center is at the origin,
        reloads the viewer and resets the view state.
        """
        if self.index <= 0:
            return
        self.index -= 1
        entry = self.anno_list[self.index]

        # Entries are "<scene_id><sep><instance>"; either path separator works.
        separator = '\\' if '\\' in entry else '/'
        scene_id, instance = entry.split(separator)
        # Label key: second-to-last '_' field with its first 5 chars dropped.
        label_key = entry.split('_')[-2][5:]
        text = scene_id + '/' + instance
        if label_key in self.ins_label.keys():
            text += '\t' + self.ins_label[label_key]
        self.label_label.setText(text)

        self.pc = np.load(os.path.join(self.loadroot, entry))
        low, high = get_bbox(self.pc)
        self.pc -= (low + high) / 2
        self.viewer.clear()
        self.viewer.load(self.pc)
        self.phi = 0
        self.theta = 0
        self.rt = np.eye(4)
        # NOTE(review): the translation uses the mean of the already re-centered
        # cloud rather than the bbox center subtracted above -- confirm intended.
        centroid = np.mean(self.pc, axis=0)
        self.rt[:3, 3] = -centroid[:3]
        self.init_viewer()
Ejemplo n.º 3
0
def read_from_dir(PATH):
    """Load images and their hand bounding-box annotations from a dataset folder.

    Expects ``<PATH>/images/<stem>.jpg`` files and, for each, an annotation
    ``<PATH>/annotations/<stem>.mat`` with a ``"boxes"`` entry.

    Args:
        PATH: Dataset root directory. A trailing path separator is optional.

    Returns:
        (x, y): ``x`` is a numpy array of the ``preprocess``-ed images and
        ``y`` a numpy array of the matching ``convert_to_yolo_output``
        targets, both in sorted-filename order.

    Raises:
        FileNotFoundError: if OpenCV cannot read one of the images.
    """
    image_dir = os.path.join(PATH, "images")
    annotation_dir = os.path.join(PATH, "annotations")

    x, y = [], []
    # Sorted so the sample order does not depend on the filesystem.
    for image in sorted(os.listdir(image_dir)):
        stem, ext = os.path.splitext(image)
        if ext != ".jpg":
            continue

        image_path = os.path.join(image_dir, image)
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError("could not read image: {}".format(image_path))
        img_32x32 = preprocess(img)
        # NOTE(review): resolution comes from the *preprocessed* image while the
        # boxes below are read from the original annotations -- confirm this is
        # what convert_to_yolo_output expects.
        resolution = img_32x32.shape[0:2]
        x.append(img_32x32)

        mat = loadmat(os.path.join(annotation_dir, stem + ".mat"))

        bbox_of_all_hands = []
        for hand in mat["boxes"][0]:
            record = hand[0, 0]
            # Four corner points (a, b, c, d) of a possibly rotated box;
            # get_bbox presumably reduces them to an axis-aligned box.
            rotated_box_coords = tuple(record[k][0] for k in range(4))
            bbox_of_all_hands.append(get_bbox(rotated_box_coords))
        y.append(convert_to_yolo_output(bbox_of_all_hands, resolution))

    return np.array(x), np.array(y)
Ejemplo n.º 4
0
def extract_bbox(mask_dir, set_):
    """Compute ``get_bbox`` for every non-augmented mask file of one set.

    Args:
        mask_dir: Root directory containing one sub-folder per set.
        set_: Set identifier; ``str(set_)`` zero-padded to two characters
            names the sub-folder (e.g. 3 -> "03").

    Returns:
        numpy array of the ``get_bbox`` results, one entry per file in
        ``os.listdir`` order; files whose name contains "augment" are skipped.
    """
    set_dir = os.path.join(mask_dir, str(set_).zfill(2))
    paths = [
        os.path.join(set_dir, name)
        for name in os.listdir(set_dir)
        if 'augment' not in name
    ]
    return np.array([get_bbox(path) for path in tqdm(paths)])
Ejemplo n.º 5
0
    def on_btn_next_clicked(self):
        """Record the current entry's flags, save it, and advance to the next.

        No-op when ``index`` is already past the end of ``anno_list``.
        Otherwise marks the current entry in ``anno_json`` as annotated,
        stores the two checkbox states, calls ``save_6d``, and advances.
        When the list is exhausted it shows an "annotation done!" box.
        Otherwise it labels, loads and re-centers the next cloud, resets the
        view, calls ``save_conf`` and clears both checkboxes.
        """
        if self.index >= len(self.anno_list):
            return

        def split_entry(entry):
            # Entries are "<scene_id><sep><instance>"; either separator works.
            return entry.split('\\' if '\\' in entry else '/')

        scene_id, instance = split_entry(self.anno_list[self.index])
        record = self.anno_json[scene_id][instance]
        record['annotated'] = True
        record['symmetric'] = self.checkBox_Sym.isChecked()
        record['delete'] = self.checkBox_del.isChecked()
        self.save_6d(scene_id, instance)

        self.index += 1
        if self.index >= len(self.anno_list):
            self.index = len(self.anno_list)
            QMessageBox.information(self, 'congratulations',
                                    'annotation done!', QMessageBox.Ok)
            return

        entry = self.anno_list[self.index]
        scene_id, instance = split_entry(entry)
        # Label key: second-to-last '_' field with its first 5 chars dropped.
        label_key = entry.split('_')[-2][5:]
        text = scene_id + '/' + instance
        if label_key in self.ins_label.keys():
            text += '\t' + self.ins_label[label_key]
        self.label_label.setText(text)

        self.pc = np.load(os.path.join(self.loadroot, entry))
        low, high = get_bbox(self.pc)
        midpoint = (low + high) / 2
        self.pc -= midpoint
        self.viewer.clear()
        self.viewer.load(self.pc)
        self.phi = 0
        self.theta = 0
        self.rt = np.eye(4)
        self.rt[:3, 3] = -midpoint[:3]
        self.init_viewer()
        self.save_conf()
        for checkbox in (self.checkBox_Sym, self.checkBox_del):
            checkbox.setChecked(False)
Ejemplo n.º 6
0
def predict_one_img(model, img_path, img_shape):
    """Predict on one image, averaging over the image and its left-right flip.

    The image is cropped to an expanded box looked up from its file name
    (``get_bbox`` then ``expand_bb``). The crop is skipped when the box is
    degenerate. The image is then resized to ``img_shape``, converted with
    ``img_to_array`` and passed through ``preprocess_func``.

    Args:
        model: Object whose ``predict`` accepts a batch array (Keras-style).
        img_path: Path to the image file.
        img_shape: Target size as (height, width).

    Returns:
        ``np.mean`` over axis 0 of the model's predictions for the batch
        [image, flipped image].
    """
    h, w = img_shape
    img = load_img(img_path)
    # The crop box is derived from the file name, not from the pixels.
    box = get_bbox(os.path.basename(img_path))
    x0, y0, x1, y1 = expand_bb(img, box)
    if not (x0 >= x1 or y0 >= y1):
        # PIL's Image.crop returns a NEW image; the result must be kept,
        # otherwise the crop is silently discarded.
        img = img.crop((x0, y0, x1, y1))
    # PIL sizes are (width, height). The default resampling filter depends
    # on the Pillow version.
    img = img.resize((w, h))
    img = img_to_array(img)
    img = preprocess_func(img)
    # Test-time augmentation: np.fliplr reverses axis 1 (width, assuming an
    # HWC array -- confirm img_to_array's data format).
    flip_img = np.fliplr(img)
    res = model.predict(np.array([img, flip_img]))
    return np.mean(res, axis=0)
Ejemplo n.º 7
0
def main():
    """Overlay the ROI from a bbox text file on each video frame; optionally save crops.

    Uses module-level settings: ``in_vid`` (input video), ``in_txt`` (a
    header line followed by one bbox line per frame), ``out_dir`` (crop
    output directory, created if missing) and ``is_print`` (whether to save
    and print crop paths). Stops at the end of the video, at the end of the
    bbox file, or when 'q' is pressed in the preview window.
    """
    os.makedirs(out_dir, exist_ok=True)
    cap = utils.load_video(in_vid)
    frame_id = 0
    try:
        # `with` guarantees the bbox file is closed even if a frame fails.
        with open(in_txt, 'r') as f:
            # Skip the header line.
            f.readline()

            while cap.isOpened():
                ok, img = cap.read()
                line = f.readline()
                if not ok or not line:
                    break
                bbox = utils.get_bbox(line)
                utils.draw_box(img, bbox, (0, 0, 255))

                cv2.imshow('frame', img)
                # `&` binds tighter than `==` in Python, so this masks first.
                if cv2.waitKey(1) & 0xff == ord('q'):
                    break

                if is_print:
                    out_path = os.path.join(out_dir, '{}.jpg'.format(frame_id))
                    print(out_path)
                    # bbox presumably (x, y, w, h); the +1 offsets appear to skip
                    # the top/left edge of the box drawn above -- confirm.
                    cv2.imwrite(
                        out_path,
                        img[bbox[1] + 1:bbox[1] + bbox[3],
                            bbox[0] + 1:bbox[0] + bbox[2]])
                frame_id += 1
    finally:
        # Release the capture and windows even when an error interrupts the loop.
        cv2.waitKey(5)
        cap.release()
        cv2.destroyAllWindows()
Ejemplo n.º 8
0
    def on_btn_start_clicked(self):
        """Begin an annotation session and show the first pending point cloud.

        Warns and returns when both ``anno_json`` and ``anno_list`` are empty.
        Otherwise calls ``init_instance()``. If ``anno_list`` is still empty
        afterwards, it reports "annotation done!". If not, it labels, loads
        and re-centers the entry at ``index``, opens a ``pptk`` viewer on it,
        and enables the navigation/edit controls.
        """
        if not (len(self.anno_json) or len(self.anno_list)):
            QMessageBox.warning(self, 'warning', 'load point clouds first!',
                                QMessageBox.Cancel)
            return
        self.init_instance()
        if len(self.anno_list) == 0:
            QMessageBox.information(self, 'congratulations',
                                    'annotation done!', QMessageBox.Ok)
            return

        entry = self.anno_list[self.index]
        # Entries are "<scene_id><sep><instance>"; either separator works.
        separator = '\\' if '\\' in entry else '/'
        scene_id, instance = entry.split(separator)
        # Label key: second-to-last '_' field with its first 5 chars dropped.
        label_key = entry.split('_')[-2][5:]
        text = scene_id + '/' + instance
        if label_key in self.ins_label.keys():
            text += '\t' + self.ins_label[label_key]
        self.label_label.setText(text)

        self.pc = np.load(os.path.join(self.loadroot, entry))
        low, high = get_bbox(self.pc)
        midpoint = (low + high) / 2
        self.pc -= midpoint
        self.rt = np.eye(4)
        self.rt[:3, 3] = -midpoint[:3]
        self.viewer = pptk.viewer(self.pc)
        self.init_viewer()

        controls = (
            self.btn_a, self.btn_d, self.btn_w, self.btn_s,
            self.btn_back, self.btn_nxt,
            self.checkBox_Sym, self.checkBox_del,
            self.btn_reset,
        )
        for widget in controls:
            widget.setEnabled(True)
Ejemplo n.º 9
0
    def find(self, lng, lat, radius, count=10, filters=tuple(), sort_func=None):
        """Return up to ``count`` entities around a geopoint.

        Candidates come from a bounding-box query on the spatial index
        ``self.idx``, using the box from ``get_bbox(lng, lat, radius)``. When
        ``self.strict_radius`` is set, a ``RadiusFilter`` runs first to keep
        only entities inside the circle. The caller's ``filters`` are then
        applied in order, and finally ``sort_func`` if given.

        Args:
            lng: longitude of geopoint
            lat: latitude of geopoint
            radius: distance from geopoint in meters
            count: maximum number of items to return
            filters: iterable of objects with an ``apply(items)`` method
            sort_func: optional function(items) returning sorted items
        Returns:
            list of at most ``count`` items
        """
        left, bottom, right, top = get_bbox(lng, lat, radius)
        query_box = (left, bottom, right, top)
        candidates = [hit.object
                      for hit in self.idx.intersection(query_box, objects=True)]

        stages = [RadiusFilter(lng, lat, radius)] if self.strict_radius else []
        stages.extend(filters)
        for stage in stages:
            candidates = stage.apply(candidates)

        if sort_func is not None:
            candidates = sort_func(candidates)
        return candidates[:count]
Ejemplo n.º 10
0
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, cfg.w)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, cfg.h)
    return cap


# Trackers carried across frames; each exposes .bbox, .set_bbox(),
# .update_pose() and .count (as used below).
trackers = []
cap = source_capture(sys.argv[1])
# Main loop: run pose inference on each frame, derive one bbox per detected
# body, and match those detection boxes against the current tracker boxes.
while True:

    ret, frame = cap.read()
    bboxes = []
    if ret:

        image, pose_list = poses.inference(frame)
        for body in pose_list:
            # body is a mapping; its values (presumably keypoints -- confirm)
            # are turned into a bbox. Each detection is kept as (bbox, body).
            bbox = utils.get_bbox(list(body.values()))
            bboxes.append((bbox, body))

        track_boxes = [tracker.bbox for tracker in trackers]
        # Judging by the uses below, `matched` holds
        # (tracker_index, detection_index) pairs.
        matched, unmatched_trackers, unmatched_detections = utils.tracker_match(
            track_boxes, [b[0] for b in bboxes])

        for idx, jdx in matched:
            trackers[idx].set_bbox(bboxes[jdx][0])
            trackers[idx].update_pose(bboxes[jdx][1])

        # NOTE(review): unmatched_detections presumably indexes `bboxes`
        # (detections), yet it is used to index `trackers` here, and the bare
        # `except` hides any resulting IndexError. Was `unmatched_trackers`
        # intended? Confirm against utils.tracker_match.
        for idx in unmatched_detections:
            try:
                trackers[idx].count += 1
            except:
                pass